zhayujie 2 роки тому
батько
коміт
c3f7e2645c

+ 2 - 1
.gitignore

@@ -29,4 +29,5 @@ plugins/banwords/lib/__pycache__
 !plugins/hello
 !plugins/role
 !plugins/keyword
-!plugins/linkai
+!plugins/linkai
+client_config.json

+ 9 - 0
app.py

@@ -8,6 +8,7 @@ from channel import channel_factory
 from common import const
 from config import load_config
 from plugins import *
+import threading
 
 
 def sigterm_handler_wrap(_signo):
@@ -46,8 +47,16 @@ def run():
         if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU]:
             PluginManager().load_plugins()
 
+        if conf().get("use_linkai"):
+            try:
+                from common import linkai_client
+                threading.Thread(target=linkai_client.start, args=(channel, )).start()
+            except Exception as e:
+                pass
+
         # startup channel
         channel.startup()
+
     except Exception as e:
         logger.error("App startup failed!")
         logger.exception(e)

+ 19 - 3
bot/linkai/link_ai_bot.py

@@ -17,7 +17,6 @@ import threading
 from common import memory, utils
 import base64
 
-
 class LinkAIBot(Bot):
     # authentication failed
     AUTH_FAILED_CODE = 401
@@ -84,7 +83,6 @@ class LinkAIBot(Bot):
             if session_message[0].get("role") == "system":
                 if app_code or model == "wenxin":
                     session_message.pop(0)
-
             body = {
                 "app_code": app_code,
                 "messages": session_message,
@@ -93,7 +91,25 @@ class LinkAIBot(Bot):
                 "top_p": conf().get("top_p", 1),
                 "frequency_penalty": conf().get("frequency_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
                 "presence_penalty": conf().get("presence_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+                "session_id": session_id,
+                "channel_type": conf().get("channel_type")
             }
+            try:
+                from linkai import LinkAIClient
+                client_id = LinkAIClient.fetch_client_id()
+                if client_id:
+                    body["client_id"] = client_id
+                    # start: client info deliver
+                    if context.kwargs.get("msg"):
+                        body["session_id"] = context.kwargs.get("msg").from_user_id
+                        if context.kwargs.get("msg").is_group:
+                            body["is_group"] = True
+                            body["group_name"] = context.kwargs.get("msg").from_user_nickname
+                            body["sender_name"] = context.kwargs.get("msg").actual_user_nickname
+                        else:
+                            body["sender_name"] = context.kwargs.get("msg").from_user_nickname
+            except Exception as e:
+                pass
             file_id = context.kwargs.get("file_id")
             if file_id:
                 body["file_id"] = file_id
@@ -230,7 +246,7 @@ class LinkAIBot(Bot):
             }
             if self.args.get("max_tokens"):
                 body["max_tokens"] = self.args.get("max_tokens")
-            headers = {"Authorization": "Bearer " +  conf().get("linkai_api_key")}
+            headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
 
             # do http request
             base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")

+ 1 - 0
channel/channel.py

@@ -8,6 +8,7 @@ from bridge.reply import *
 
 
 class Channel(object):
+    channel_type = ""
     NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
 
     def startup(self):

+ 16 - 18
channel/channel_factory.py

@@ -2,43 +2,41 @@
 channel factory
 """
 from common import const
+from .channel import Channel
 
-def create_channel(channel_type):
+
+def create_channel(channel_type) -> Channel:
     """
     create a channel instance
     :param channel_type: channel type code
     :return: channel instance
     """
+    ch = Channel()
     if channel_type == "wx":
         from channel.wechat.wechat_channel import WechatChannel
-
-        return WechatChannel()
+        ch = WechatChannel()
     elif channel_type == "wxy":
         from channel.wechat.wechaty_channel import WechatyChannel
-
-        return WechatyChannel()
+        ch = WechatyChannel()
     elif channel_type == "terminal":
         from channel.terminal.terminal_channel import TerminalChannel
-
-        return TerminalChannel()
+        ch = TerminalChannel()
     elif channel_type == "wechatmp":
         from channel.wechatmp.wechatmp_channel import WechatMPChannel
-
-        return WechatMPChannel(passive_reply=True)
+        ch = WechatMPChannel(passive_reply=True)
     elif channel_type == "wechatmp_service":
         from channel.wechatmp.wechatmp_channel import WechatMPChannel
-
-        return WechatMPChannel(passive_reply=False)
+        ch = WechatMPChannel(passive_reply=False)
     elif channel_type == "wechatcom_app":
         from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
-
-        return WechatComAppChannel()
+        ch = WechatComAppChannel()
     elif channel_type == "wework":
         from channel.wework.wework_channel import WeworkChannel
-        return WeworkChannel()
-
+        ch = WeworkChannel()
     elif channel_type == const.FEISHU:
         from channel.feishu.feishu_channel import FeiShuChanel
-        return FeiShuChanel()
-
-    raise RuntimeError
+        ch = FeiShuChanel()
+    else:
+        raise RuntimeError
+    ch.channel_type = channel_type
+    return ch

+ 8 - 4
channel/feishu/feishu_channel.py

@@ -51,10 +51,14 @@ class FeiShuChanel(ChatChannel):
         web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
 
     def send(self, reply: Reply, context: Context):
-        msg = context["msg"]
+        msg = context.get("msg")
         is_group = context["isgroup"]
+        if msg:
+            access_token = msg.access_token
+        else:
+            access_token = self.fetch_access_token()
         headers = {
-            "Authorization": "Bearer " + msg.access_token,
+            "Authorization": "Bearer " + access_token,
             "Content-Type": "application/json",
         }
         msg_type = "text"
@@ -63,7 +67,7 @@ class FeiShuChanel(ChatChannel):
         content_key = "text"
         if reply.type == ReplyType.IMAGE_URL:
             # 图片上传
-            reply_content = self._upload_image_url(reply.content, msg.access_token)
+            reply_content = self._upload_image_url(reply.content, access_token)
             if not reply_content:
                 logger.warning("[FeiShu] upload file failed")
                 return
@@ -79,7 +83,7 @@ class FeiShuChanel(ChatChannel):
             res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10))
         else:
             url = "https://open.feishu.cn/open-apis/im/v1/messages"
-            params = {"receive_id_type": context.get("receive_id_type")}
+            params = {"receive_id_type": context.get("receive_id_type") or "open_id"}
             data = {
                 "receive_id": context.get("receiver"),
                 "msg_type": msg_type,

+ 11 - 0
channel/wechat/wechat_channel.py

@@ -109,6 +109,7 @@ class WechatChannel(ChatChannel):
     def __init__(self):
         super().__init__()
         self.receivedMsgs = ExpiredDict(60 * 60)
+        self.auto_login_times = 0
 
     def startup(self):
         itchat.instance.receivingRetryCount = 600  # 修改断线超时时间
@@ -120,6 +121,8 @@ class WechatChannel(ChatChannel):
             hotReload=hotReload,
             statusStorageDir=status_path,
             qrCallback=qrCallback,
+            exitCallback=self.exitCallback,
+            loginCallback=self.loginCallback
         )
         self.user_id = itchat.instance.storageClass.userName
         self.name = itchat.instance.storageClass.nickName
@@ -127,6 +130,14 @@ class WechatChannel(ChatChannel):
         # start message listener
         itchat.run()
 
+    def exitCallback(self):
+        self.auto_login_times += 1
+        if self.auto_login_times < 100:
+            self.startup()
+
+    def loginCallback(self):
+        pass
+
     # handle_* 系列函数处理收到的消息后构造Context,然后传入produce函数中处理Context和发送回复
     # Context包含了消息的所有信息,包括以下属性
     #   type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE

+ 28 - 0
common/linkai_client.py

@@ -0,0 +1,28 @@
+from bridge.context import Context, ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from linkai import LinkAIClient, PushMsg
+from config import conf
+
+
+class ChatClient(LinkAIClient):
+    def __init__(self, api_key, host, channel):
+        super().__init__(api_key, host)
+        self.channel = channel
+        self.client_type = channel.channel_type
+
+    def on_message(self, push_msg: PushMsg):
+        session_id = push_msg.session_id
+        msg_content = push_msg.msg_content
+        logger.info(f"receive msg push, session_id={session_id}, msg_content={msg_content}")
+        context = Context()
+        context.type = ContextType.TEXT
+        context["receiver"] = session_id
+        context["isgroup"] = push_msg.is_group
+        self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
+
+
+def start(channel):
+    client = ChatClient(api_key=conf().get("linkai_api_key"),
+                        host="link-ai.chat", channel=channel)
+    client.start()