Przeglądaj źródła

Merge branch 'master' into master

zhayujie 2 lat temu
rodzic
commit
eda3ba92fd

+ 19 - 14
app.py

@@ -3,6 +3,7 @@
 import os
 import os
 import signal
 import signal
 import sys
 import sys
+import time
 
 
 from channel import channel_factory
 from channel import channel_factory
 from common import const
 from common import const
@@ -24,6 +25,21 @@ def sigterm_handler_wrap(_signo):
     signal.signal(_signo, func)
     signal.signal(_signo, func)
 
 
 
 
+def start_channel(channel_name: str):
+    channel = channel_factory.create_channel(channel_name)
+    if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework",
+                        const.FEISHU, const.DINGTALK]:
+        PluginManager().load_plugins()
+
+    if conf().get("use_linkai"):
+        try:
+            from common import linkai_client
+            threading.Thread(target=linkai_client.start, args=(channel,)).start()
+        except Exception as e:
+            pass
+    channel.startup()
+
+
 def run():
 def run():
     try:
     try:
         # load config
         # load config
@@ -41,22 +57,11 @@ def run():
 
 
         if channel_name == "wxy":
         if channel_name == "wxy":
             os.environ["WECHATY_LOG"] = "warn"
             os.environ["WECHATY_LOG"] = "warn"
-            # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
-
-        channel = channel_factory.create_channel(channel_name)
-        if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU,const.DINGTALK]:
-            PluginManager().load_plugins()
-
-        if conf().get("use_linkai"):
-            try:
-                from common import linkai_client
-                threading.Thread(target=linkai_client.start, args=(channel, )).start()
-            except Exception as e:
-                pass
 
 
-        # startup channel
-        channel.startup()
+        start_channel(channel_name)
 
 
+        while True:
+            time.sleep(1)
     except Exception as e:
     except Exception as e:
         logger.error("App startup failed!")
         logger.error("App startup failed!")
         logger.exception(e)
         logger.exception(e)

+ 1 - 0
bot/bot_factory.py

@@ -56,4 +56,5 @@ def create_bot(bot_type):
         from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
         from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
         return ZHIPUAIBot()
         return ZHIPUAIBot()
 
 
+
     raise RuntimeError
     raise RuntimeError

+ 6 - 2
bot/linkai/link_ai_bot.py

@@ -107,7 +107,11 @@ class LinkAIBot(Bot):
                             body["group_name"] = context.kwargs.get("msg").from_user_nickname
                             body["group_name"] = context.kwargs.get("msg").from_user_nickname
                             body["sender_name"] = context.kwargs.get("msg").actual_user_nickname
                             body["sender_name"] = context.kwargs.get("msg").actual_user_nickname
                         else:
                         else:
-                            body["sender_name"] = context.kwargs.get("msg").from_user_nickname
+                            if body.get("channel_type") in ["wechatcom_app"]:
+                                body["sender_name"] = context.kwargs.get("msg").from_user_id
+                            else:
+                                body["sender_name"] = context.kwargs.get("msg").from_user_nickname
+
             except Exception as e:
             except Exception as e:
                 pass
                 pass
             file_id = context.kwargs.get("file_id")
             file_id = context.kwargs.get("file_id")
@@ -396,7 +400,7 @@ class LinkAIBot(Bot):
                 i += 1
                 i += 1
                 if url.endswith(".mp4"):
                 if url.endswith(".mp4"):
                     reply_type = ReplyType.VIDEO_URL
                     reply_type = ReplyType.VIDEO_URL
-                elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx"):
+                elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx") or url.endswith(".csv"):
                     reply_type = ReplyType.FILE
                     reply_type = ReplyType.FILE
                     url = _download_file(url)
                     url = _download_file(url)
                     if not url:
                     if not url:

+ 155 - 0
bot/zhipu/chat_glm_bot.py

@@ -0,0 +1,155 @@
+# encoding:utf-8
+
+import time
+
+import openai
+import openai.error
+import requests
+
+from bot.bot import Bot
+from bot.zhipu.chat_glm_session import ChatGLMSession
+from bot.openai.open_ai_image import OpenAIImage
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+# from common.token_bucket import TokenBucket
+from config import conf, load_config
+from zhipuai import ZhipuAI
+
+
+# ZhipuAI对话模型API
+class ChatGLMBot(Bot):
+    def __init__(self):
+        super().__init__()
+        # set the default api_key
+        self.api_key = conf().get("zhipu_ai_api_key")
+        if conf().get("zhipu_ai_api_base"):
+            openai.api_base = conf().get("zhipu_ai_api_base")
+        # if conf().get("rate_limit_chatgpt"):
+        #     self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
+
+        self.sessions = SessionManager(ChatGLMSession, model=conf().get("model") or "chatglm")
+        self.args = {
+            "model": "glm-4",  # 对话模型的名称
+            "temperature": conf().get("temperature", 0.9),  # 值在[0,1]之间,越大表示回复越具有不确定性
+            # "max_tokens":4096,  # 回复最大的字符数
+            "top_p": conf().get("top_p", 0.7),
+            # "frequency_penalty": conf().get("frequency_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            # "presence_penalty": conf().get("presence_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            # "request_timeout": conf().get("request_timeout", None),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+            # "timeout": conf().get("request_timeout", None),  # 重试超时时间,在这个时间内,将会自动重试
+        }
+        self.client = ZhipuAI(api_key=self.api_key)
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if context.type == ContextType.TEXT:
+            logger.info("[CHATGLM] query={}".format(query))
+
+            session_id = context["session_id"]
+            reply = None
+            clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
+            if query in clear_memory_commands:
+                self.sessions.clear_session(session_id)
+                reply = Reply(ReplyType.INFO, "记忆已清除")
+            elif query == "#清除所有":
+                self.sessions.clear_all_session()
+                reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+            elif query == "#更新配置":
+                load_config()
+                reply = Reply(ReplyType.INFO, "配置已更新")
+            if reply:
+                return reply
+            session = self.sessions.session_query(query, session_id)
+            logger.debug("[CHATGLM] session query={}".format(session.messages))
+
+            api_key = context.get("openai_api_key") or openai.api_key
+            model = context.get("gpt_model")
+            new_args = None
+            if model:
+                new_args = self.args.copy()
+                new_args["model"] = model
+            # if context.get('stream'):
+            #     # reply in stream
+            #     return self.reply_text_stream(query, new_query, session_id)
+
+            reply_content = self.reply_text(session, api_key, args=new_args)
+            logger.debug(
+                "[CHATGLM] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+                    session.messages,
+                    session_id,
+                    reply_content["content"],
+                    reply_content["completion_tokens"],
+                )
+            )
+            if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
+            elif reply_content["completion_tokens"] > 0:
+                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
+                reply = Reply(ReplyType.TEXT, reply_content["content"])
+            else:
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
+                logger.debug("[CHATGLM] reply {} used 0 tokens.".format(reply_content))
+            return reply
+        else:
+            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+            return reply
+
+    def reply_text(self, session: ChatGLMSession, api_key=None, args=None, retry_count=0) -> dict:
+        """
+        call openai's ChatCompletion to get the answer
+        :param session: a conversation session
+        :param session_id: session id
+        :param retry_count: retry count
+        :return: {}
+        """
+        try:
+            # if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
+            #     raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
+            # if api_key == None, the default openai.api_key will be used
+            if args is None:
+                args = self.args
+            # response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
+            response = self.client.chat.completions.create(messages=session.messages, **args)
+            # logger.debug("[CHATGLM] response={}".format(response))
+            # logger.info("[CHATGLM] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
+            return {
+                "total_tokens": response.usage.total_tokens,
+                "completion_tokens": response.usage.completion_tokens,
+                "content": response.choices[0].message.content,
+            }
+        except Exception as e:
+            need_retry = retry_count < 2
+            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+            if isinstance(e, openai.error.RateLimitError):
+                logger.warn("[CHATGLM] RateLimitError: {}".format(e))
+                result["content"] = "提问太快啦,请休息一下再问我吧"
+                if need_retry:
+                    time.sleep(20)
+            elif isinstance(e, openai.error.Timeout):
+                logger.warn("[CHATGLM] Timeout: {}".format(e))
+                result["content"] = "我没有收到你的消息"
+                if need_retry:
+                    time.sleep(5)
+            elif isinstance(e, openai.error.APIError):
+                logger.warn("[CHATGLM] Bad Gateway: {}".format(e))
+                result["content"] = "请再问我一次"
+                if need_retry:
+                    time.sleep(10)
+            elif isinstance(e, openai.error.APIConnectionError):
+                logger.warn("[CHATGLM] APIConnectionError: {}".format(e))
+                result["content"] = "我连接不到你的网络"
+                if need_retry:
+                    time.sleep(5)
+            else:
+                logger.exception("[CHATGLM] Exception: {}".format(e), e)
+                need_retry = False
+                self.sessions.clear_session(session.session_id)
+
+            if need_retry:
+                logger.warn("[CHATGLM] 第{}次重试".format(retry_count + 1))
+                return self.reply_text(session, api_key, args, retry_count + 1)
+            else:
+                return result
+

+ 48 - 0
bot/zhipu/chat_glm_session.py

@@ -0,0 +1,48 @@
+from bot.session_manager import Session
+from common.log import logger
+
+class ChatGLMSession(Session):
+    def __init__(self, session_id, system_prompt=None, model="glm-4"):
+        super().__init__(session_id, system_prompt)
+        self.model = model
+        self.reset()
+
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
+        precise = True
+        try:
+            cur_tokens = self.calc_tokens()
+        except Exception as e:
+            precise = False
+            if cur_tokens is None:
+                raise e
+            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+        while cur_tokens > max_tokens:
+            if len(self.messages) > 2:
+                self.messages.pop(1)
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
+                self.messages.pop(1)
+                if precise:
+                    cur_tokens = self.calc_tokens()
+                else:
+                    cur_tokens = cur_tokens - max_tokens
+                break
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
+                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
+                break
+            else:
+                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
+                break
+            if precise:
+                cur_tokens = self.calc_tokens()
+            else:
+                cur_tokens = cur_tokens - max_tokens
+        return cur_tokens
+
+    def calc_tokens(self):
+        return num_tokens_from_messages(self.messages, self.model)
+
+def num_tokens_from_messages(messages, model):
+    tokens = 0
+    for msg in messages:
+        tokens += len(msg["content"])
+    return tokens

+ 4 - 2
channel/chat_channel.py

@@ -4,6 +4,7 @@ import threading
 import time
 import time
 from asyncio import CancelledError
 from asyncio import CancelledError
 from concurrent.futures import Future, ThreadPoolExecutor
 from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent import futures
 
 
 from bridge.context import *
 from bridge.context import *
 from bridge.reply import *
 from bridge.reply import *
@@ -17,6 +18,8 @@ try:
 except Exception as e:
 except Exception as e:
     pass
     pass
 
 
+handler_pool = ThreadPoolExecutor(max_workers=8)  # 处理消息的线程池
+
 
 
 # 抽象类, 它包含了与消息通道无关的通用处理逻辑
 # 抽象类, 它包含了与消息通道无关的通用处理逻辑
 class ChatChannel(Channel):
 class ChatChannel(Channel):
@@ -25,7 +28,6 @@ class ChatChannel(Channel):
     futures = {}  # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
     futures = {}  # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
     sessions = {}  # 用于控制并发,每个session_id同时只能有一个context在处理
     sessions = {}  # 用于控制并发,每个session_id同时只能有一个context在处理
     lock = threading.Lock()  # 用于控制对sessions的访问
     lock = threading.Lock()  # 用于控制对sessions的访问
-    handler_pool = ThreadPoolExecutor(max_workers=8)  # 处理消息的线程池
 
 
     def __init__(self):
     def __init__(self):
         _thread = threading.Thread(target=self.consume)
         _thread = threading.Thread(target=self.consume)
@@ -339,7 +341,7 @@ class ChatChannel(Channel):
                         if not context_queue.empty():
                         if not context_queue.empty():
                             context = context_queue.get()
                             context = context_queue.get()
                             logger.debug("[WX] consume context: {}".format(context))
                             logger.debug("[WX] consume context: {}".format(context))
-                            future: Future = self.handler_pool.submit(self._handle, context)
+                            future: Future = handler_pool.submit(self._handle, context)
                             future.add_done_callback(self._thread_pool_callback(session_id, context=context))
                             future.add_done_callback(self._thread_pool_callback(session_id, context=context))
                             if session_id not in self.futures:
                             if session_id not in self.futures:
                                 self.futures[session_id] = []
                                 self.futures[session_id] = []

+ 38 - 25
channel/wechat/wechat_channel.py

@@ -15,6 +15,7 @@ import requests
 from bridge.context import *
 from bridge.context import *
 from bridge.reply import *
 from bridge.reply import *
 from channel.chat_channel import ChatChannel
 from channel.chat_channel import ChatChannel
+from channel import chat_channel
 from channel.wechat.wechat_message import *
 from channel.wechat.wechat_message import *
 from common.expired_dict import ExpiredDict
 from common.expired_dict import ExpiredDict
 from common.log import logger
 from common.log import logger
@@ -112,30 +113,39 @@ class WechatChannel(ChatChannel):
         self.auto_login_times = 0
         self.auto_login_times = 0
 
 
     def startup(self):
     def startup(self):
-        itchat.instance.receivingRetryCount = 600  # 修改断线超时时间
-        # login by scan QRCode
-        hotReload = conf().get("hot_reload", False)
-        status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
-        itchat.auto_login(
-            enableCmdQR=2,
-            hotReload=hotReload,
-            statusStorageDir=status_path,
-            qrCallback=qrCallback,
-            exitCallback=self.exitCallback,
-            loginCallback=self.loginCallback
-        )
-        self.user_id = itchat.instance.storageClass.userName
-        self.name = itchat.instance.storageClass.nickName
-        logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
-        # start message listener
-        itchat.run()
+        try:
+            itchat.instance.receivingRetryCount = 600  # 修改断线超时时间
+            # login by scan QRCode
+            hotReload = conf().get("hot_reload", False)
+            status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
+            itchat.auto_login(
+                enableCmdQR=2,
+                hotReload=hotReload,
+                statusStorageDir=status_path,
+                qrCallback=qrCallback,
+                exitCallback=self.exitCallback,
+                loginCallback=self.loginCallback
+            )
+            self.user_id = itchat.instance.storageClass.userName
+            self.name = itchat.instance.storageClass.nickName
+            logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
+            # start message listener
+            itchat.run()
+        except Exception as e:
+            logger.error(e)
 
 
     def exitCallback(self):
     def exitCallback(self):
-        _send_logout()
-        time.sleep(3)
-        self.auto_login_times += 1
-        if self.auto_login_times < 100:
-            self.startup()
+        try:
+            from common.linkai_client import chat_client
+            if chat_client.client_id and conf().get("use_linkai"):
+                _send_logout()
+                time.sleep(2)
+                self.auto_login_times += 1
+                if self.auto_login_times < 100:
+                    chat_channel.handler_pool._shutdown = False
+                    self.startup()
+        except Exception as e:
+            pass
 
 
     def loginCallback(self):
     def loginCallback(self):
         logger.debug("Login success")
         logger.debug("Login success")
@@ -251,20 +261,23 @@ class WechatChannel(ChatChannel):
 def _send_login_success():
 def _send_login_success():
     try:
     try:
         from common.linkai_client import chat_client
         from common.linkai_client import chat_client
-        chat_client.send_login_success()
+        if chat_client.client_id:
+            chat_client.send_login_success()
     except Exception as e:
     except Exception as e:
         pass
         pass
 
 
 def _send_logout():
 def _send_logout():
     try:
     try:
         from common.linkai_client import chat_client
         from common.linkai_client import chat_client
-        chat_client.send_logout()
+        if chat_client.client_id:
+            chat_client.send_logout()
     except Exception as e:
     except Exception as e:
         pass
         pass
 
 
 def _send_qr_code(qrcode_list: list):
 def _send_qr_code(qrcode_list: list):
     try:
     try:
         from common.linkai_client import chat_client
         from common.linkai_client import chat_client
-        chat_client.send_qrcode(qrcode_list)
+        if chat_client.client_id:
+            chat_client.send_qrcode(qrcode_list)
     except Exception as e:
     except Exception as e:
         pass
         pass

+ 1 - 0
common/const.py

@@ -10,6 +10,7 @@ QWEN = "qwen"
 GEMINI = "gemini"
 GEMINI = "gemini"
 ZHIPU_AI = "glm-4"
 ZHIPU_AI = "glm-4"
 
 
+
 # model
 # model
 GPT35 = "gpt-3.5-turbo"
 GPT35 = "gpt-3.5-turbo"
 GPT4 = "gpt-4"
 GPT4 = "gpt-4"

+ 26 - 1
common/linkai_client.py

@@ -2,7 +2,9 @@ from bridge.context import Context, ContextType
 from bridge.reply import Reply, ReplyType
 from bridge.reply import Reply, ReplyType
 from common.log import logger
 from common.log import logger
 from linkai import LinkAIClient, PushMsg
 from linkai import LinkAIClient, PushMsg
-from config import conf
+from config import conf, pconf, plugin_config
+from plugins import PluginManager
+
 
 
 chat_client: LinkAIClient
 chat_client: LinkAIClient
 
 
@@ -22,6 +24,29 @@ class ChatClient(LinkAIClient):
         context["isgroup"] = push_msg.is_group
         context["isgroup"] = push_msg.is_group
         self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
         self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
 
 
+    def on_config(self, config: dict):
+        if not self.client_id:
+            return
+        logger.info(f"从控制台加载配置: {config}")
+        local_config = conf()
+        for key in local_config.keys():
+            if config.get(key) is not None:
+                local_config[key] = config.get(key)
+        if config.get("reply_voice_mode"):
+            if config.get("reply_voice_mode") == "voice_reply_voice":
+                local_config["voice_reply_voice"] = True
+            elif config.get("reply_voice_mode") == "always_reply_voice":
+                local_config["always_reply_voice"] = True
+        # if config.get("admin_password") and plugin_config["Godcmd"]:
+        #     plugin_config["Godcmd"]["password"] = config.get("admin_password")
+        #     PluginManager().instances["Godcmd"].reload()
+        # if config.get("group_app_map") and pconf("linkai"):
+        #     local_group_map = {}
+        #     for mapping in config.get("group_app_map"):
+        #         local_group_map[mapping.get("group_name")] = mapping.get("app_code")
+        #     pconf("linkai")["group_app_map"] = local_group_map
+        #     PluginManager().instances["linkai"].reload()
+
 
 
 def start(channel):
 def start(channel):
     global chat_client
     global chat_client

+ 0 - 1
config.py

@@ -159,7 +159,6 @@ available_setting = {
     # 智谱AI 平台配置
     # 智谱AI 平台配置
     "zhipu_ai_api_key": "",
     "zhipu_ai_api_key": "",
     "zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
     "zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
-
 }
 }
 
 
 
 

+ 8 - 0
plugins/godcmd/godcmd.py

@@ -475,3 +475,11 @@ class Godcmd(Plugin):
         if model == "gpt-4-turbo":
         if model == "gpt-4-turbo":
             return const.GPT4_TURBO_PREVIEW
             return const.GPT4_TURBO_PREVIEW
         return model
         return model
+
+    def reload(self):
+        gconf = plugin_config[self.name]
+        if gconf:
+            if gconf.get("password"):
+                self.password = gconf["password"]
+            if gconf.get("admin_users"):
+                self.admin_users = gconf["admin_users"]

+ 3 - 0
plugins/plugin.py

@@ -46,3 +46,6 @@ class Plugin:
 
 
     def get_help_text(self, **kwargs):
     def get_help_text(self, **kwargs):
         return "暂无帮助信息"
         return "暂无帮助信息"
+
+    def reload(self):
+        pass

+ 1 - 1
requirements-optional.txt

@@ -39,4 +39,4 @@ linkai
 dingtalk_stream
 dingtalk_stream
 
 
 # zhipuai
 # zhipuai
-zhipuai>=2.0.1
+zhipuai>=2.0.1