Răsfoiți Sursa

Merge branch 'master' into patch-1

zhayujie 2 ani în urmă
părinte
comite
654ebe93e7

+ 3 - 4
README.md

@@ -1,13 +1,13 @@
 # 简介
 # 简介
 
 
-> 本项目是基于大模型的智能对话机器人,支持微信、企业微信、公众号、飞书、钉钉接入,可选择GPT3.5/GPT4.0/Claude/文心一言/讯飞星火/通义千问/Gemini/LinkAI,能处理文本、语音和图片,通过插件访问操作系统和互联网等外部资源,支持基于自有知识库定制企业AI应用。
+> 本项目是基于大模型的智能对话机器人,支持微信、企业微信、公众号、飞书、钉钉接入,可选择GPT3.5/GPT4.0/Claude/文心一言/讯飞星火/通义千问/Gemini/LinkAI/ZhipuAI,能处理文本、语音和图片,通过插件访问操作系统和互联网等外部资源,支持基于自有知识库定制企业AI应用。
 
 
 最新版本支持的功能如下:
 最新版本支持的功能如下:
 
 
 - [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信、微信公众号和、企业微信、飞书、钉钉等部署方式
 - [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信、微信公众号和、企业微信、飞书、钉钉等部署方式
-- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4, claude, Gemini, 文心一言, 讯飞星火, 通义千问
+- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4, claude, Gemini, 文心一言, 讯飞星火, 通义千问,ChatGLM
 - [x] **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型
 - [x] **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型
-- [x] **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, vision模型
+- [x] **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, CogView-3, vision模型
 - [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索等插件
 - [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索等插件
 - [x] **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、智能客服、私域助手使用,基于 [LinkAI](https://link-ai.tech) 实现
 - [x] **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、智能客服、私域助手使用,基于 [LinkAI](https://link-ai.tech) 实现
 
 
@@ -23,7 +23,6 @@ Demo made by [Visionn](https://www.wangpc.cc/)
 SaaS服务、私有化部署、稳定托管接入 等多种模式。
 SaaS服务、私有化部署、稳定托管接入 等多种模式。
 >
 >
 > 目前已在私域运营、智能客服、企业效率助手等场景积累了丰富的 AI 解决方案, 在电商、文教、健康、新消费等各行业沉淀了 AI 落地的最佳实践,致力于打造助力中小企业拥抱 AI 的一站式平台。
 > 目前已在私域运营、智能客服、企业效率助手等场景积累了丰富的 AI 解决方案, 在电商、文教、健康、新消费等各行业沉淀了 AI 落地的最佳实践,致力于打造助力中小企业拥抱 AI 的一站式平台。
-
 企业服务和商用咨询可联系产品顾问:
 企业服务和商用咨询可联系产品顾问:
 
 
 <img width="240" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/product-manager-qrcode.jpg">
 <img width="240" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/product-manager-qrcode.jpg">

+ 19 - 14
app.py

@@ -3,6 +3,7 @@
 import os
 import os
 import signal
 import signal
 import sys
 import sys
+import time
 
 
 from channel import channel_factory
 from channel import channel_factory
 from common import const
 from common import const
@@ -24,6 +25,21 @@ def sigterm_handler_wrap(_signo):
     signal.signal(_signo, func)
     signal.signal(_signo, func)
 
 
 
 
+def start_channel(channel_name: str):
+    channel = channel_factory.create_channel(channel_name)
+    if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework",
+                        const.FEISHU, const.DINGTALK]:
+        PluginManager().load_plugins()
+
+    if conf().get("use_linkai"):
+        try:
+            from common import linkai_client
+            threading.Thread(target=linkai_client.start, args=(channel,)).start()
+        except Exception as e:
+            pass
+    channel.startup()
+
+
 def run():
 def run():
     try:
     try:
         # load config
         # load config
@@ -41,22 +57,11 @@ def run():
 
 
         if channel_name == "wxy":
         if channel_name == "wxy":
             os.environ["WECHATY_LOG"] = "warn"
             os.environ["WECHATY_LOG"] = "warn"
-            # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
-
-        channel = channel_factory.create_channel(channel_name)
-        if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU,const.DINGTALK]:
-            PluginManager().load_plugins()
-
-        if conf().get("use_linkai"):
-            try:
-                from common import linkai_client
-                threading.Thread(target=linkai_client.start, args=(channel, )).start()
-            except Exception as e:
-                pass
 
 
-        # startup channel
-        channel.startup()
+        start_channel(channel_name)
 
 
+        while True:
+            time.sleep(1)
     except Exception as e:
     except Exception as e:
         logger.error("App startup failed!")
         logger.error("App startup failed!")
         logger.exception(e)
         logger.exception(e)

+ 5 - 0
bot/bot_factory.py

@@ -52,4 +52,9 @@ def create_bot(bot_type):
         from bot.gemini.google_gemini_bot import GoogleGeminiBot
         from bot.gemini.google_gemini_bot import GoogleGeminiBot
         return GoogleGeminiBot()
         return GoogleGeminiBot()
 
 
+    elif bot_type == const.ZHIPU_AI:
+        from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
+        return ZHIPUAIBot()
+
+
     raise RuntimeError
     raise RuntimeError

+ 3 - 0
bot/gemini/google_gemini_bot.py

@@ -44,6 +44,7 @@ class GoogleGeminiBot(Bot):
         except Exception as e:
         except Exception as e:
             logger.error("[Gemini] fetch reply error, may contain unsafe content")
             logger.error("[Gemini] fetch reply error, may contain unsafe content")
             logger.error(e)
             logger.error(e)
+            return Reply(ReplyType.ERROR, "invoke [Gemini] api failed!")
 
 
     def _convert_to_gemini_messages(self, messages: list):
     def _convert_to_gemini_messages(self, messages: list):
         res = []
         res = []
@@ -63,6 +64,8 @@ class GoogleGeminiBot(Bot):
     def _filter_messages(self, messages: list):
     def _filter_messages(self, messages: list):
         res = []
         res = []
         turn = "user"
         turn = "user"
+        if not messages:
+            return res
         for i in range(len(messages) - 1, -1, -1):
         for i in range(len(messages) - 1, -1, -1):
             message = messages[i]
             message = messages[i]
             if message.get("role") != turn:
             if message.get("role") != turn:

+ 3 - 2
bot/linkai/link_ai_bot.py

@@ -92,7 +92,8 @@ class LinkAIBot(Bot):
                 "frequency_penalty": conf().get("frequency_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
                 "frequency_penalty": conf().get("frequency_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
                 "presence_penalty": conf().get("presence_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
                 "presence_penalty": conf().get("presence_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
                 "session_id": session_id,
                 "session_id": session_id,
-                "channel_type": conf().get("channel_type")
+                "sender_id": session_id,
+                "channel_type": conf().get("channel_type", "wx")
             }
             }
             try:
             try:
                 from linkai import LinkAIClient
                 from linkai import LinkAIClient
@@ -400,7 +401,7 @@ class LinkAIBot(Bot):
                 i += 1
                 i += 1
                 if url.endswith(".mp4"):
                 if url.endswith(".mp4"):
                     reply_type = ReplyType.VIDEO_URL
                     reply_type = ReplyType.VIDEO_URL
-                elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx"):
+                elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx") or url.endswith(".csv"):
                     reply_type = ReplyType.FILE
                     reply_type = ReplyType.FILE
                     url = _download_file(url)
                     url = _download_file(url)
                     if not url:
                     if not url:

+ 3 - 2
bot/xunfei/xunfei_spark_bot.py

@@ -46,8 +46,9 @@ class XunFeiBot(Bot):
         self.domain = "generalv3"
         self.domain = "generalv3"
         # 默认使用v2.0版本: "ws://spark-api.xf-yun.com/v2.1/chat"
         # 默认使用v2.0版本: "ws://spark-api.xf-yun.com/v2.1/chat"
         # v1.5版本为: "ws://spark-api.xf-yun.com/v1.1/chat"
         # v1.5版本为: "ws://spark-api.xf-yun.com/v1.1/chat"
-        # v3.5版本为: "ws://spark-api.xf-yun.com/v3.5/chat"
-        self.spark_url = "ws://spark-api.xf-yun.com/v3.5/chat"
+        # v3.0版本为: "ws://spark-api.xf-yun.com/v3.1/chat"
+        # v3.5版本为: "wss://spark-api.xf-yun.com/v3.5/chat"
+        self.spark_url = "wss://spark-api.xf-yun.com/v3.5/chat"
         self.host = urlparse(self.spark_url).netloc
         self.host = urlparse(self.spark_url).netloc
         self.path = urlparse(self.spark_url).path
         self.path = urlparse(self.spark_url).path
         # 和wenxin使用相同的session机制
         # 和wenxin使用相同的session机制

+ 29 - 0
bot/zhipuai/zhipu_ai_image.py

@@ -0,0 +1,29 @@
+from common.log import logger
+from config import conf
+
+
+# ZhipuAI提供的画图接口
+
+class ZhipuAIImage(object):
+    def __init__(self):
+        from zhipuai import ZhipuAI
+        self.client = ZhipuAI(api_key=conf().get("zhipu_ai_api_key"))
+
+    def create_img(self, query, retry_count=0, api_key=None, api_base=None):
+        try:
+            if conf().get("rate_limit_dalle"):
+                return False, "请求太快了,请休息一下再问我吧"
+            logger.info("[ZHIPU_AI] image_query={}".format(query))
+            response = self.client.images.generations(
+                prompt=query,
+                n=1,  # 每次生成图片的数量
+                model=conf().get("text_to_image") or "cogview-3",
+                size=conf().get("image_create_size", "1024x1024"),  # 图片大小,可选有 256x256, 512x512, 1024x1024
+                quality="standard",
+            )
+            image_url = response.data[0].url
+            logger.info("[ZHIPU_AI] image_url={}".format(image_url))
+            return True, image_url
+        except Exception as e:
+            logger.exception(e)
+            return False, "画图出现问题,请休息一下再问我吧"

+ 51 - 0
bot/zhipuai/zhipu_ai_session.py

@@ -0,0 +1,51 @@
+from bot.session_manager import Session
+from common.log import logger
+
+
+class ZhipuAISession(Session):
+    def __init__(self, session_id, system_prompt=None, model="glm-4"):
+        super().__init__(session_id, system_prompt)
+        self.model = model
+        self.reset()
+
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
+        precise = True
+        try:
+            cur_tokens = self.calc_tokens()
+        except Exception as e:
+            precise = False
+            if cur_tokens is None:
+                raise e
+            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+        while cur_tokens > max_tokens:
+            if len(self.messages) > 2:
+                self.messages.pop(1)
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
+                self.messages.pop(1)
+                if precise:
+                    cur_tokens = self.calc_tokens()
+                else:
+                    cur_tokens = cur_tokens - max_tokens
+                break
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
+                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
+                break
+            else:
+                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
+                                                                                       len(self.messages)))
+                break
+            if precise:
+                cur_tokens = self.calc_tokens()
+            else:
+                cur_tokens = cur_tokens - max_tokens
+        return cur_tokens
+
+    def calc_tokens(self):
+        return num_tokens_from_messages(self.messages, self.model)
+
+
+def num_tokens_from_messages(messages, model):
+    tokens = 0
+    for msg in messages:
+        tokens += len(msg["content"])
+    return tokens

+ 149 - 0
bot/zhipuai/zhipuai_bot.py

@@ -0,0 +1,149 @@
+# encoding:utf-8
+
+import time
+
+import openai
+import openai.error
+from bot.bot import Bot
+from bot.zhipuai.zhipu_ai_session import ZhipuAISession
+from bot.zhipuai.zhipu_ai_image import ZhipuAIImage
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf, load_config
+from zhipuai import ZhipuAI
+
+
+# ZhipuAI对话模型API
+class ZHIPUAIBot(Bot, ZhipuAIImage):
+    def __init__(self):
+        super().__init__()
+        self.sessions = SessionManager(ZhipuAISession, model=conf().get("model") or "ZHIPU_AI")
+        self.args = {
+            "model": conf().get("model") or "glm-4",  # 对话模型的名称
+            "temperature": conf().get("temperature", 0.9),  # 值在(0,1)之间(智谱AI 的温度不能取 0 或者 1)
+            "top_p": conf().get("top_p", 0.7),  # 值在(0,1)之间(智谱AI 的 top_p 不能取 0 或者 1)
+        }
+        self.client = ZhipuAI(api_key=conf().get("zhipu_ai_api_key"))
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if context.type == ContextType.TEXT:
+            logger.info("[ZHIPU_AI] query={}".format(query))
+
+            session_id = context["session_id"]
+            reply = None
+            clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
+            if query in clear_memory_commands:
+                self.sessions.clear_session(session_id)
+                reply = Reply(ReplyType.INFO, "记忆已清除")
+            elif query == "#清除所有":
+                self.sessions.clear_all_session()
+                reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+            elif query == "#更新配置":
+                load_config()
+                reply = Reply(ReplyType.INFO, "配置已更新")
+            if reply:
+                return reply
+            session = self.sessions.session_query(query, session_id)
+            logger.debug("[ZHIPU_AI] session query={}".format(session.messages))
+
+            api_key = context.get("openai_api_key") or openai.api_key
+            model = context.get("gpt_model")
+            new_args = None
+            if model:
+                new_args = self.args.copy()
+                new_args["model"] = model
+            # if context.get('stream'):
+            #     # reply in stream
+            #     return self.reply_text_stream(query, new_query, session_id)
+
+            reply_content = self.reply_text(session, api_key, args=new_args)
+            logger.debug(
+                "[ZHIPU_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+                    session.messages,
+                    session_id,
+                    reply_content["content"],
+                    reply_content["completion_tokens"],
+                )
+            )
+            if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
+            elif reply_content["completion_tokens"] > 0:
+                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
+                reply = Reply(ReplyType.TEXT, reply_content["content"])
+            else:
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
+                logger.debug("[ZHIPU_AI] reply {} used 0 tokens.".format(reply_content))
+            return reply
+        elif context.type == ContextType.IMAGE_CREATE:
+            ok, retstring = self.create_img(query, 0)
+            reply = None
+            if ok:
+                reply = Reply(ReplyType.IMAGE_URL, retstring)
+            else:
+                reply = Reply(ReplyType.ERROR, retstring)
+            return reply
+
+        else:
+            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+            return reply
+
+    def reply_text(self, session: ZhipuAISession, api_key=None, args=None, retry_count=0) -> dict:
+        """
+        call openai's ChatCompletion to get the answer
+        :param session: a conversation session
+        :param session_id: session id
+        :param retry_count: retry count
+        :return: {}
+        """
+        try:
+            # if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
+            #     raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
+            # if api_key == None, the default openai.api_key will be used
+            if args is None:
+                args = self.args
+            # response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
+            response = self.client.chat.completions.create(messages=session.messages, **args)
+            # logger.debug("[ZHIPU_AI] response={}".format(response))
+            # logger.info("[ZHIPU_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
+
+            return {
+                "total_tokens": response.usage.total_tokens,
+                "completion_tokens": response.usage.completion_tokens,
+                "content": response.choices[0].message.content,
+            }
+        except Exception as e:
+            need_retry = retry_count < 2
+            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+            if isinstance(e, openai.error.RateLimitError):
+                logger.warn("[ZHIPU_AI] RateLimitError: {}".format(e))
+                result["content"] = "提问太快啦,请休息一下再问我吧"
+                if need_retry:
+                    time.sleep(20)
+            elif isinstance(e, openai.error.Timeout):
+                logger.warn("[ZHIPU_AI] Timeout: {}".format(e))
+                result["content"] = "我没有收到你的消息"
+                if need_retry:
+                    time.sleep(5)
+            elif isinstance(e, openai.error.APIError):
+                logger.warn("[ZHIPU_AI] Bad Gateway: {}".format(e))
+                result["content"] = "请再问我一次"
+                if need_retry:
+                    time.sleep(10)
+            elif isinstance(e, openai.error.APIConnectionError):
+                logger.warn("[ZHIPU_AI] APIConnectionError: {}".format(e))
+                result["content"] = "我连接不到你的网络"
+                if need_retry:
+                    time.sleep(5)
+            else:
+                logger.exception("[ZHIPU_AI] Exception: {}".format(e), e)
+                need_retry = False
+                self.sessions.clear_session(session.session_id)
+
+            if need_retry:
+                logger.warn("[ZHIPU_AI] 第{}次重试".format(retry_count + 1))
+                return self.reply_text(session, api_key, args, retry_count + 1)
+            else:
+                return result

+ 2 - 0
bridge/bridge.py

@@ -31,6 +31,8 @@ class Bridge(object):
             self.btype["chat"] = const.QWEN
             self.btype["chat"] = const.QWEN
         if model_type in [const.GEMINI]:
         if model_type in [const.GEMINI]:
             self.btype["chat"] = const.GEMINI
             self.btype["chat"] = const.GEMINI
+        if model_type in [const.ZHIPU_AI]:
+            self.btype["chat"] = const.ZHIPU_AI
 
 
         if conf().get("use_linkai") and conf().get("linkai_api_key"):
         if conf().get("use_linkai") and conf().get("linkai_api_key"):
             self.btype["chat"] = const.LINKAI
             self.btype["chat"] = const.LINKAI

+ 1 - 1
bridge/reply.py

@@ -11,7 +11,7 @@ class ReplyType(Enum):
     VIDEO_URL = 5  # 视频URL
     VIDEO_URL = 5  # 视频URL
     FILE = 6  # 文件
     FILE = 6  # 文件
     CARD = 7  # 微信名片,仅支持ntchat
     CARD = 7  # 微信名片,仅支持ntchat
-    InviteRoom = 8  # 邀请好友进群
+    INVITE_ROOM = 8  # 邀请好友进群
     INFO = 9
     INFO = 9
     ERROR = 10
     ERROR = 10
     TEXT_ = 11  # 强制文本
     TEXT_ = 11  # 强制文本

+ 9 - 5
channel/chat_channel.py

@@ -4,6 +4,7 @@ import threading
 import time
 import time
 from asyncio import CancelledError
 from asyncio import CancelledError
 from concurrent.futures import Future, ThreadPoolExecutor
 from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent import futures
 
 
 from bridge.context import *
 from bridge.context import *
 from bridge.reply import *
 from bridge.reply import *
@@ -17,6 +18,8 @@ try:
 except Exception as e:
 except Exception as e:
     pass
     pass
 
 
+handler_pool = ThreadPoolExecutor(max_workers=8)  # 处理消息的线程池
+
 
 
 # 抽象类, 它包含了与消息通道无关的通用处理逻辑
 # 抽象类, 它包含了与消息通道无关的通用处理逻辑
 class ChatChannel(Channel):
 class ChatChannel(Channel):
@@ -25,7 +28,6 @@ class ChatChannel(Channel):
     futures = {}  # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
     futures = {}  # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
     sessions = {}  # 用于控制并发,每个session_id同时只能有一个context在处理
     sessions = {}  # 用于控制并发,每个session_id同时只能有一个context在处理
     lock = threading.Lock()  # 用于控制对sessions的访问
     lock = threading.Lock()  # 用于控制对sessions的访问
-    handler_pool = ThreadPoolExecutor(max_workers=8)  # 处理消息的线程池
 
 
     def __init__(self):
     def __init__(self):
         _thread = threading.Thread(target=self.consume)
         _thread = threading.Thread(target=self.consume)
@@ -168,11 +170,13 @@ class ChatChannel(Channel):
         reply = self._generate_reply(context)
         reply = self._generate_reply(context)
 
 
         logger.debug("[WX] ready to decorate reply: {}".format(reply))
         logger.debug("[WX] ready to decorate reply: {}".format(reply))
+
         # reply的包装步骤
         # reply的包装步骤
-        reply = self._decorate_reply(context, reply)
+        if reply and reply.content:
+            reply = self._decorate_reply(context, reply)
 
 
-        # reply的发送步骤
-        self._send_reply(context, reply)
+            # reply的发送步骤
+            self._send_reply(context, reply)
 
 
     def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
     def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
         e_context = PluginManager().emit_event(
         e_context = PluginManager().emit_event(
@@ -339,7 +343,7 @@ class ChatChannel(Channel):
                         if not context_queue.empty():
                         if not context_queue.empty():
                             context = context_queue.get()
                             context = context_queue.get()
                             logger.debug("[WX] consume context: {}".format(context))
                             logger.debug("[WX] consume context: {}".format(context))
-                            future: Future = self.handler_pool.submit(self._handle, context)
+                            future: Future = handler_pool.submit(self._handle, context)
                             future.add_done_callback(self._thread_pool_callback(session_id, context=context))
                             future.add_done_callback(self._thread_pool_callback(session_id, context=context))
                             if session_id not in self.futures:
                             if session_id not in self.futures:
                                 self.futures[session_id] = []
                                 self.futures[session_id] = []

+ 32 - 25
channel/wechat/wechat_channel.py

@@ -15,6 +15,7 @@ import requests
 from bridge.context import *
 from bridge.context import *
 from bridge.reply import *
 from bridge.reply import *
 from channel.chat_channel import ChatChannel
 from channel.chat_channel import ChatChannel
+from channel import chat_channel
 from channel.wechat.wechat_message import *
 from channel.wechat.wechat_message import *
 from common.expired_dict import ExpiredDict
 from common.expired_dict import ExpiredDict
 from common.log import logger
 from common.log import logger
@@ -112,30 +113,39 @@ class WechatChannel(ChatChannel):
         self.auto_login_times = 0
         self.auto_login_times = 0
 
 
     def startup(self):
     def startup(self):
-        itchat.instance.receivingRetryCount = 600  # 修改断线超时时间
-        # login by scan QRCode
-        hotReload = conf().get("hot_reload", False)
-        status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
-        itchat.auto_login(
-            enableCmdQR=2,
-            hotReload=hotReload,
-            statusStorageDir=status_path,
-            qrCallback=qrCallback,
-            exitCallback=self.exitCallback,
-            loginCallback=self.loginCallback
-        )
-        self.user_id = itchat.instance.storageClass.userName
-        self.name = itchat.instance.storageClass.nickName
-        logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
-        # start message listener
-        itchat.run()
+        try:
+            itchat.instance.receivingRetryCount = 600  # 修改断线超时时间
+            # login by scan QRCode
+            hotReload = conf().get("hot_reload", False)
+            status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
+            itchat.auto_login(
+                enableCmdQR=2,
+                hotReload=hotReload,
+                statusStorageDir=status_path,
+                qrCallback=qrCallback,
+                exitCallback=self.exitCallback,
+                loginCallback=self.loginCallback
+            )
+            self.user_id = itchat.instance.storageClass.userName
+            self.name = itchat.instance.storageClass.nickName
+            logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
+            # start message listener
+            itchat.run()
+        except Exception as e:
+            logger.error(e)
 
 
     def exitCallback(self):
     def exitCallback(self):
-        _send_logout()
-        time.sleep(3)
-        self.auto_login_times += 1
-        if self.auto_login_times < 100:
-            self.startup()
+        try:
+            from common.linkai_client import chat_client
+            if chat_client.client_id and conf().get("use_linkai"):
+                _send_logout()
+                time.sleep(2)
+                self.auto_login_times += 1
+                if self.auto_login_times < 100:
+                    chat_channel.handler_pool._shutdown = False
+                    self.startup()
+        except Exception as e:
+            pass
 
 
     def loginCallback(self):
     def loginCallback(self):
         logger.debug("Login success")
         logger.debug("Login success")
@@ -223,7 +233,6 @@ class WechatChannel(ChatChannel):
             logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
             logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
         elif reply.type == ReplyType.IMAGE:  # 从文件读取图片
         elif reply.type == ReplyType.IMAGE:  # 从文件读取图片
             image_storage = reply.content
             image_storage = reply.content
-            image_storage.seek(0)
             itchat.send_image(image_storage, toUserName=receiver)
             itchat.send_image(image_storage, toUserName=receiver)
             logger.info("[WX] sendImage, receiver={}".format(receiver))
             logger.info("[WX] sendImage, receiver={}".format(receiver))
         elif reply.type == ReplyType.FILE:  # 新增文件回复类型
         elif reply.type == ReplyType.FILE:  # 新增文件回复类型
@@ -259,7 +268,6 @@ def _send_login_success():
 def _send_logout():
 def _send_logout():
     try:
     try:
         from common.linkai_client import chat_client
         from common.linkai_client import chat_client
-        time.sleep(2)
         if chat_client.client_id:
         if chat_client.client_id:
             chat_client.send_logout()
             chat_client.send_logout()
     except Exception as e:
     except Exception as e:
@@ -268,7 +276,6 @@ def _send_logout():
 def _send_qr_code(qrcode_list: list):
 def _send_qr_code(qrcode_list: list):
     try:
     try:
         from common.linkai_client import chat_client
         from common.linkai_client import chat_client
-        time.sleep(2)
         if chat_client.client_id:
         if chat_client.client_id:
             chat_client.send_qrcode(qrcode_list)
             chat_client.send_qrcode(qrcode_list)
     except Exception as e:
     except Exception as e:

+ 3 - 1
common/const.py

@@ -8,6 +8,8 @@ LINKAI = "linkai"
 CLAUDEAI = "claude"
 CLAUDEAI = "claude"
 QWEN = "qwen"
 QWEN = "qwen"
 GEMINI = "gemini"
 GEMINI = "gemini"
+ZHIPU_AI = "glm-4"
+
 
 
 # model
 # model
 GPT35 = "gpt-3.5-turbo"
 GPT35 = "gpt-3.5-turbo"
@@ -19,7 +21,7 @@ TTS_1 = "tts-1"
 TTS_1_HD = "tts-1-hd"
 TTS_1_HD = "tts-1-hd"
 
 
 MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo",
 MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo",
-              "gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI]
+              "gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI]
 
 
 # channel
 # channel
 FEISHU = "feishu"
 FEISHU = "feishu"

+ 26 - 1
common/linkai_client.py

@@ -2,7 +2,9 @@ from bridge.context import Context, ContextType
 from bridge.reply import Reply, ReplyType
 from bridge.reply import Reply, ReplyType
 from common.log import logger
 from common.log import logger
 from linkai import LinkAIClient, PushMsg
 from linkai import LinkAIClient, PushMsg
-from config import conf
+from config import conf, pconf, plugin_config
+from plugins import PluginManager
+
 
 
 chat_client: LinkAIClient
 chat_client: LinkAIClient
 
 
@@ -22,6 +24,29 @@ class ChatClient(LinkAIClient):
         context["isgroup"] = push_msg.is_group
         context["isgroup"] = push_msg.is_group
         self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
         self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
 
 
+    def on_config(self, config: dict):
+        if not self.client_id:
+            return
+        logger.info(f"从控制台加载配置: {config}")
+        local_config = conf()
+        for key in local_config.keys():
+            if config.get(key) is not None:
+                local_config[key] = config.get(key)
+        if config.get("reply_voice_mode"):
+            if config.get("reply_voice_mode") == "voice_reply_voice":
+                local_config["voice_reply_voice"] = True
+            elif config.get("reply_voice_mode") == "always_reply_voice":
+                local_config["always_reply_voice"] = True
+        # if config.get("admin_password") and plugin_config["Godcmd"]:
+        #     plugin_config["Godcmd"]["password"] = config.get("admin_password")
+        #     PluginManager().instances["Godcmd"].reload()
+        # if config.get("group_app_map") and pconf("linkai"):
+        #     local_group_map = {}
+        #     for mapping in config.get("group_app_map"):
+        #         local_group_map[mapping.get("group_name")] = mapping.get("app_code")
+        #     pconf("linkai")["group_app_map"] = local_group_map
+        #     PluginManager().instances["linkai"].reload()
+
 
 
 def start(channel):
 def start(channel):
     global chat_client
     global chat_client

+ 4 - 1
config.py

@@ -83,7 +83,7 @@ available_setting = {
     "voice_reply_voice": False,  # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
     "voice_reply_voice": False,  # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
     "always_reply_voice": False,  # 是否一直使用语音回复
     "always_reply_voice": False,  # 是否一直使用语音回复
     "voice_to_text": "openai",  # 语音识别引擎,支持openai,baidu,google,azure
     "voice_to_text": "openai",  # 语音识别引擎,支持openai,baidu,google,azure
-    "text_to_voice": "openai",  # 语音合成引擎,支持openai,baidu,google,pytts(offline),azure,elevenlabs
+    "text_to_voice": "openai",  # 语音合成引擎,支持openai,baidu,google,pytts(offline),azure,elevenlabs,edge(online)
     "text_to_voice_model": "tts-1",
     "text_to_voice_model": "tts-1",
     "tts_voice_id": "alloy",
     "tts_voice_id": "alloy",
     # baidu 语音api配置, 使用百度语音识别和语音合成时需要
     # baidu 语音api配置, 使用百度语音识别和语音合成时需要
@@ -150,6 +150,9 @@ available_setting = {
     "use_global_plugin_config": False,
     "use_global_plugin_config": False,
     "max_media_send_count": 3,     # 单次最大发送媒体资源的个数
     "max_media_send_count": 3,     # 单次最大发送媒体资源的个数
     "media_send_interval": 1,  # 发送图片的事件间隔,单位秒
     "media_send_interval": 1,  # 发送图片的事件间隔,单位秒
+    # 智谱AI 平台配置
+    "zhipu_ai_api_key": "",
+    "zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
     # LinkAI平台配置
     # LinkAI平台配置
     "use_linkai": False,
     "use_linkai": False,
     "linkai_api_key": "",
     "linkai_api_key": "",

+ 8 - 0
plugins/godcmd/godcmd.py

@@ -475,3 +475,11 @@ class Godcmd(Plugin):
         if model == "gpt-4-turbo":
         if model == "gpt-4-turbo":
             return const.GPT4_TURBO_PREVIEW
             return const.GPT4_TURBO_PREVIEW
         return model
         return model
+
+    def reload(self):
+        gconf = plugin_config[self.name]
+        if gconf:
+            if gconf.get("password"):
+                self.password = gconf["password"]
+            if gconf.get("admin_users"):
+                self.admin_users = gconf["admin_users"]

+ 3 - 0
plugins/plugin.py

@@ -46,3 +46,6 @@ class Plugin:
 
 
     def get_help_text(self, **kwargs):
     def get_help_text(self, **kwargs):
         return "暂无帮助信息"
         return "暂无帮助信息"
+
+    def reload(self):
+        pass

+ 16 - 14
plugins/plugin_manager.py

@@ -99,7 +99,7 @@ class PluginManager:
                     try:
                     try:
                         self.current_plugin_path = plugin_path
                         self.current_plugin_path = plugin_path
                         if plugin_path in self.loaded:
                         if plugin_path in self.loaded:
-                            if self.loaded[plugin_path] == None:
+                            if plugin_name.upper() != 'GODCMD':
                                 logger.info("reload module %s" % plugin_name)
                                 logger.info("reload module %s" % plugin_name)
                                 self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
                                 self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
                                 dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")]
                                 dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")]
@@ -141,19 +141,21 @@ class PluginManager:
         failed_plugins = []
         failed_plugins = []
         for name, plugincls in self.plugins.items():
         for name, plugincls in self.plugins.items():
             if plugincls.enabled:
             if plugincls.enabled:
-                if name not in self.instances:
-                    try:
-                        instance = plugincls()
-                    except Exception as e:
-                        logger.warn("Failed to init %s, diabled. %s" % (name, e))
-                        self.disable_plugin(name)
-                        failed_plugins.append(name)
-                        continue
-                    self.instances[name] = instance
-                    for event in instance.handlers:
-                        if event not in self.listening_plugins:
-                            self.listening_plugins[event] = []
-                        self.listening_plugins[event].append(name)
+                if 'GODCMD' in self.instances and name == 'GODCMD':
+                    continue
+                # if name not in self.instances:
+                try:
+                    instance = plugincls()
+                except Exception as e:
+                    logger.warn("Failed to init %s, diabled. %s" % (name, e))
+                    self.disable_plugin(name)
+                    failed_plugins.append(name)
+                    continue
+                self.instances[name] = instance
+                for event in instance.handlers:
+                    if event not in self.listening_plugins:
+                        self.listening_plugins[event] = []
+                    self.listening_plugins[event].append(name)
         self.refresh_order()
         self.refresh_order()
         return failed_plugins
         return failed_plugins
 
 

+ 4 - 0
plugins/source.json

@@ -20,5 +20,9 @@
       "url": "https://github.com/6vision/Apilot.git",
       "url": "https://github.com/6vision/Apilot.git",
       "desc": "通过api直接查询早报、热榜、快递、天气等实用信息的插件"
       "desc": "通过api直接查询早报、热榜、快递、天气等实用信息的插件"
     }
     }
+    "pictureChange": {
+      "url": "https://github.com/Yanyutin753/pictureChange.git",
+      "desc": "利用stable-diffusion和百度Ai进行图生图或者画图的插件"
+    }
   }
   }
 }
 }

+ 1 - 1
plugins/tool/tool.py

@@ -137,7 +137,7 @@ class Tool(Plugin):
 
 
         return {
         return {
             # 全局配置相关
             # 全局配置相关
-            "log": True,  # tool 日志开关
+            "log": False,  # tool 日志开关
             "debug": kwargs.get("debug", False),  # 输出更多日志
             "debug": kwargs.get("debug", False),  # 输出更多日志
             "no_default": kwargs.get("no_default", False),  # 不要默认的工具,只加载自己导入的工具
             "no_default": kwargs.get("no_default", False),  # 不要默认的工具,只加载自己导入的工具
             "think_depth": kwargs.get("think_depth", 2),  # 一个问题最多使用多少次工具
             "think_depth": kwargs.get("think_depth", 2),  # 一个问题最多使用多少次工具

+ 5 - 1
requirements-optional.txt

@@ -7,6 +7,7 @@ gTTS>=2.3.1 # google text to speech
 pyttsx3>=2.90 # pytsx text to speech
 pyttsx3>=2.90 # pytsx text to speech
 baidu_aip>=4.16.10 # baidu voice
 baidu_aip>=4.16.10 # baidu voice
 azure-cognitiveservices-speech # azure voice
 azure-cognitiveservices-speech # azure voice
+edge-tts # edge-tts
 numpy<=1.24.2
 numpy<=1.24.2
 langid # language detect
 langid # language detect
 
 
@@ -33,7 +34,10 @@ broadscope_bailian
 google-generativeai
 google-generativeai
 
 
 # linkai
 # linkai
-linkai
+linkai>=0.0.3.5
 
 
 # dingtalk
 # dingtalk
 dingtalk_stream
 dingtalk_stream
+
+# zhipuai
+zhipuai>=2.0.1

+ 3 - 1
voice/audio_convert.py

@@ -64,7 +64,9 @@ def any_to_wav(any_path, wav_path):
     if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
     if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
         return sil_to_wav(any_path, wav_path)
         return sil_to_wav(any_path, wav_path)
     audio = AudioSegment.from_file(any_path)
     audio = AudioSegment.from_file(any_path)
-    audio.export(wav_path, format="wav")
+    audio.set_frame_rate(8000)    # 百度语音转写支持8000采样率, pcm_s16le, 单通道语音识别
+    audio.set_channels(1)
+    audio.export(wav_path, format="wav", codec='pcm_s16le')
 
 
 
 
 def any_to_sil(any_path, sil_path):
 def any_to_sil(any_path, sil_path):

+ 1 - 1
voice/baidu/baidu_voice.py

@@ -62,7 +62,7 @@ class BaiduVoice(Voice):
         # 识别本地文件
         # 识别本地文件
         logger.debug("[Baidu] voice file name={}".format(voice_file))
         logger.debug("[Baidu] voice file name={}".format(voice_file))
         pcm = get_pcm_from_wav(voice_file)
         pcm = get_pcm_from_wav(voice_file)
-        res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id})
+        res = self.client.asr(pcm, "pcm", 8000, {"dev_pid": self.dev_id})
         if res["err_no"] == 0:
         if res["err_no"] == 0:
             logger.info("百度语音识别到了:{}".format(res["result"]))
             logger.info("百度语音识别到了:{}".format(res["result"]))
             text = "".join(res["result"])
             text = "".join(res["result"])

+ 50 - 0
voice/edge/edge_voice.py

@@ -0,0 +1,50 @@
+import time
+
+import edge_tts
+import asyncio
+
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from common.tmp_dir import TmpDir
+from voice.voice import Voice
+
+
+class EdgeVoice(Voice):
+
+    def __init__(self):
+        '''
+        # 普通话
+        zh-CN-XiaoxiaoNeural
+        zh-CN-XiaoyiNeural
+        zh-CN-YunjianNeural
+        zh-CN-YunxiNeural
+        zh-CN-YunxiaNeural
+        zh-CN-YunyangNeural
+        # 地方口音
+        zh-CN-liaoning-XiaobeiNeural
+        zh-CN-shaanxi-XiaoniNeural
+        # 粤语
+        zh-HK-HiuGaaiNeural
+        zh-HK-HiuMaanNeural
+        zh-HK-WanLungNeural
+        # 湾湾腔
+        zh-TW-HsiaoChenNeural
+        zh-TW-HsiaoYuNeural
+        zh-TW-YunJheNeural
+        '''
+        self.voice = "zh-CN-YunjianNeural"
+
+    def voiceToText(self, voice_file):
+        pass
+
+    async def gen_voice(self, text, fileName):
+        communicate = edge_tts.Communicate(text, self.voice)
+        await communicate.save(fileName)
+
+    def textToVoice(self, text):
+        fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
+
+        asyncio.run(self.gen_voice(text, fileName))
+
+        logger.info("[EdgeTTS] textToVoice text={} voice file name={}".format(text, fileName))
+        return Reply(ReplyType.VOICE, fileName)

+ 4 - 0
voice/factory.py

@@ -42,4 +42,8 @@ def create_voice(voice_type):
         from voice.ali.ali_voice import AliVoice
         from voice.ali.ali_voice import AliVoice
 
 
         return AliVoice()
         return AliVoice()
+    elif voice_type == "edge":
+        from voice.edge.edge_voice import EdgeVoice
+
+        return EdgeVoice()
     raise RuntimeError
     raise RuntimeError