Ver Fonte

feat: add xunfei spark bot

zhayujie há 2 anos atrás
pai
commit
a086f1989f

+ 12 - 3
README.md

@@ -5,7 +5,7 @@
 最新版本支持的功能如下:
 
 - [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信,微信公众号和企业微信应用等部署方式
-- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3, GPT-3.5, GPT-4, 文心一言模型
+- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3, GPT-3.5, GPT-4, 文心一言, 讯飞星火
 - [x] **语音识别:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai等多种语音模型
 - [x] **图片生成:** 支持图片生成 和 图生图(如照片修复),可选择 Dell-E, stable diffusion, replicate, midjourney模型
 - [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结等插件
@@ -113,7 +113,7 @@ pip3 install azure-cognitiveservices-speech
 # config.json文件内容示例
 {
   "open_ai_api_key": "YOUR API KEY",                          # 填入上面创建的 OpenAI API KEY
-  "model": "gpt-3.5-turbo",                                   # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
+  "model": "gpt-3.5-turbo",                                   # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
   "proxy": "",                                                # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
   "single_chat_prefix": ["bot", "@bot"],                      # 私聊时文本需要包含该前缀才能触发机器人回复
   "single_chat_reply_prefix": "[bot] ",                       # 私聊时自动回复的前缀,用于区分真人
@@ -129,7 +129,10 @@ pip3 install azure-cognitiveservices-speech
   "azure_api_version": "",                                    # 采用Azure ChatGPT时,API版本
   "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",  # 人格描述
   # 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
-  "subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。"
+  "subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。",
+  "use_linkai": false,                                        # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ
+  "linkai_api_key": "",                                       # LinkAI Api Key
+  "linkai_app_code": ""                                       # LinkAI 应用code
 }
 ```
 **配置说明:**
@@ -166,6 +169,12 @@ pip3 install azure-cognitiveservices-speech
 + `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格      (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
 + `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
 
+**5.LinkAI配置 (可选)**
+
++ `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
++ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://chat.link-ai.tech/console/interface) 创建
++ `linkai_app_code`: LinkAI 应用code,选填
+
 **本说明文档可能会未及时更新,当前所有可选的配置项均在该[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
 
 ## 运行

+ 4 - 6
bot/bot_factory.py

@@ -14,31 +14,29 @@ def create_bot(bot_type):
         # 替换Baidu Unit为Baidu文心千帆对话接口
         # from bot.baidu.baidu_unit_bot import BaiduUnitBot
         # return BaiduUnitBot()
-
         from bot.baidu.baidu_wenxin import BaiduWenxinBot
-
         return BaiduWenxinBot()
 
     elif bot_type == const.CHATGPT:
         # ChatGPT 网页端web接口
         from bot.chatgpt.chat_gpt_bot import ChatGPTBot
-
         return ChatGPTBot()
 
     elif bot_type == const.OPEN_AI:
         # OpenAI 官方对话模型API
         from bot.openai.open_ai_bot import OpenAIBot
-
         return OpenAIBot()
 
     elif bot_type == const.CHATGPTONAZURE:
         # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
         from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
-
         return AzureChatGPTBot()
 
+    elif bot_type == const.XUNFEI:
+        from bot.xunfei.xunfei_spark_bot import XunFeiBot
+        return XunFeiBot()
+
     elif bot_type == const.LINKAI:
         from bot.linkai.link_ai_bot import LinkAIBot
         return LinkAIBot()
-
     raise RuntimeError

+ 14 - 1
bot/chatgpt/chat_gpt_session.py

@@ -55,11 +55,16 @@ class ChatGPTSession(Session):
 # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 def num_tokens_from_messages(messages, model):
     """Returns the number of tokens used by a list of messages."""
+
+    if model in ["wenxin", "xunfei"]:
+        return num_tokens_by_character(messages)
+
     import tiktoken
 
     if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo"]:
         return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
-    elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]:
+    elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
+                   "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]:
         return num_tokens_from_messages(messages, model="gpt-4")
 
     try:
@@ -85,3 +90,11 @@ def num_tokens_from_messages(messages, model):
                 num_tokens += tokens_per_name
     num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
     return num_tokens
+
+
+def num_tokens_by_character(messages):
+    """Returns the number of tokens used by a list of messages."""
+    tokens = 0
+    for msg in messages:
+        tokens += len(msg["content"])
+    return tokens

+ 246 - 0
bot/xunfei/xunfei_spark_bot.py

@@ -0,0 +1,246 @@
+# encoding:utf-8
+
+import requests, json
+from bot.bot import Bot
+from bot.session_manager import SessionManager
+from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
+from bridge.context import ContextType, Context
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+from common import const
+import time
+import _thread as thread
+import datetime
+from datetime import datetime
+from wsgiref.handlers import format_date_time
+from urllib.parse import urlencode
+import base64
+import ssl
+import hashlib
+import hmac
+import json
+from time import mktime
+from urllib.parse import urlparse
+import websocket
+import queue
+import threading
+import random
+
+# 消息队列 map
+queue_map = dict()
+
+
+class XunFeiBot(Bot):
+    def __init__(self):
+        super().__init__()
+        self.app_id = conf().get("xunfei_app_id")
+        self.api_key = conf().get("xunfei_api_key")
+        self.api_secret = conf().get("xunfei_api_secret")
+        # 默认使用v2.0版本,1.5版本可设置为 general
+        self.domain = "generalv2"
+        # 默认使用v2.0版本,1.5版本可设置为 "ws://spark-api.xf-yun.com/v1.1/chat"
+        self.spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"
+        self.host = urlparse(self.spark_url).netloc
+        self.path = urlparse(self.spark_url).path
+        self.answer = ""
+        # 和wenxin使用相同的session机制
+        self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
+
+    def reply(self, query, context: Context = None) -> Reply:
+        if context.type == ContextType.TEXT:
+            logger.info("[XunFei] query={}".format(query))
+            session_id = context["session_id"]
+            request_id = self.gen_request_id(session_id)
+            session = self.sessions.session_query(query, session_id)
+            threading.Thread(target=self.create_web_socket, args=(session.messages, request_id)).start()
+            depth = 0
+            time.sleep(0.1)
+            t1 = time.time()
+            usage = {}
+            while depth <= 300:
+                try:
+                    data_queue = queue_map.get(request_id)
+                    if not data_queue:
+                        depth += 1
+                        time.sleep(0.1)
+                        continue
+                    data_item = data_queue.get(block=True, timeout=0.1)
+                    if data_item.is_end:
+                        # 请求结束
+                        del queue_map[request_id]
+                        if data_item.reply:
+                            self.answer += data_item.reply
+                        usage = data_item.usage
+                        break
+
+                    self.answer += data_item.reply
+                    depth += 1
+                except Exception as e:
+                    depth += 1
+                    continue
+            t2 = time.time()
+            logger.info(f"[XunFei-API] response={self.answer}, time={t2 - t1}s, usage={usage}")
+            self.sessions.session_reply(self.answer, session_id, usage.get("total_tokens"))
+            reply = Reply(ReplyType.TEXT, self.answer)
+            return reply
+        else:
+            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+            return reply
+
+    def create_web_socket(self, prompt, session_id, temperature=0.5):
+        logger.info(f"[XunFei] start connect, prompt={prompt}")
+        websocket.enableTrace(False)
+        wsUrl = self.create_url()
+        ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close,
+                                    on_open=on_open)
+        data_queue = queue.Queue(1000)
+        queue_map[session_id] = data_queue
+        ws.appid = self.app_id
+        ws.question = prompt
+        ws.domain = self.domain
+        ws.session_id = session_id
+        ws.temperature = temperature
+        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
+
+    def gen_request_id(self, session_id: str):
+        return session_id + "_" + str(int(time.time())) + "" + str(random.randint(0, 100))
+
+    # 生成url
+    def create_url(self):
+        # 生成RFC1123格式的时间戳
+        now = datetime.now()
+        date = format_date_time(mktime(now.timetuple()))
+
+        # 拼接字符串
+        signature_origin = "host: " + self.host + "\n"
+        signature_origin += "date: " + date + "\n"
+        signature_origin += "GET " + self.path + " HTTP/1.1"
+
+        # 进行hmac-sha256进行加密
+        signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
+                                 digestmod=hashlib.sha256).digest()
+
+        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
+
+        authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
+                               f'signature="{signature_sha_base64}"'
+
+        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
+
+        # 将请求的鉴权参数组合为字典
+        v = {
+            "authorization": authorization,
+            "date": date,
+            "host": self.host
+        }
+        # 拼接鉴权参数,生成url
+        url = self.spark_url + '?' + urlencode(v)
+        # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
+        return url
+
+    def gen_params(self, appid, domain, question):
+        """
+        通过appid和用户的提问来生成请参数
+        """
+        data = {
+            "header": {
+                "app_id": appid,
+                "uid": "1234"
+            },
+            "parameter": {
+                "chat": {
+                    "domain": domain,
+                    "random_threshold": 0.5,
+                    "max_tokens": 2048,
+                    "auditing": "default"
+                }
+            },
+            "payload": {
+                "message": {
+                    "text": question
+                }
+            }
+        }
+        return data
+
+
+class ReplyItem:
+    def __init__(self, reply, usage=None, is_end=False):
+        self.is_end = is_end
+        self.reply = reply
+        self.usage = usage
+
+
+# 收到websocket错误的处理
+def on_error(ws, error):
+    logger.error("[XunFei] error:", error)
+
+
+# 收到websocket关闭的处理
+def on_close(ws, one, two):
+    data_queue = queue_map.get(ws.session_id)
+    data_queue.put("END")
+
+
+# 收到websocket连接建立的处理
+def on_open(ws):
+    logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
+    thread.start_new_thread(run, (ws,))
+
+
+def run(ws, *args):
+    data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question, temperature=ws.temperature))
+    ws.send(data)
+
+
+# Websocket 操作
+# 收到websocket消息的处理
+def on_message(ws, message):
+    data = json.loads(message)
+    code = data['header']['code']
+    if code != 0:
+        logger.error(f'请求错误: {code}, {data}')
+        ws.close()
+    else:
+        choices = data["payload"]["choices"]
+        status = choices["status"]
+        content = choices["text"][0]["content"]
+        data_queue = queue_map.get(ws.session_id)
+        if not data_queue:
+            logger.error(f"[XunFei] can't find data queue, session_id={ws.session_id}")
+            return
+        reply_item = ReplyItem(content)
+        if status == 2:
+            usage = data["payload"].get("usage")
+            reply_item = ReplyItem(content, usage)
+            reply_item.is_end = True
+            ws.close()
+        data_queue.put(reply_item)
+
+
+def gen_params(appid, domain, question, temperature=0.5):
+    """
+    通过appid和用户的提问来生成请参数
+    """
+    data = {
+        "header": {
+            "app_id": appid,
+            "uid": "1234"
+        },
+        "parameter": {
+            "chat": {
+                "domain": domain,
+                "temperature": temperature,
+                "random_threshold": 0.5,
+                "max_tokens": 2048,
+                "auditing": "default"
+            }
+        },
+        "payload": {
+            "message": {
+                "text": question
+            }
+        }
+    }
+    return data

+ 2 - 0
bridge/bridge.py

@@ -25,6 +25,8 @@ class Bridge(object):
             self.btype["chat"] = const.CHATGPTONAZURE
         if model_type in ["wenxin"]:
             self.btype["chat"] = const.BAIDU
+        if model_type in ["xunfei"]:
+            self.btype["chat"] = const.XUNFEI
         if conf().get("use_linkai") and conf().get("linkai_api_key"):
             self.btype["chat"] = const.LINKAI
         self.bots = {}

+ 1 - 0
common/const.py

@@ -2,6 +2,7 @@
 OPEN_AI = "openAI"
 CHATGPT = "chatGPT"
 BAIDU = "baidu"
+XUNFEI = "xunfei"
 CHATGPTONAZURE = "chatGPTOnAzure"
 LINKAI = "linkai"
 

+ 8 - 4
config.py

@@ -16,7 +16,7 @@ available_setting = {
     "open_ai_api_base": "https://api.openai.com/v1",
     "proxy": "",  # openai使用的代理
     # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
-    "model": "gpt-3.5-turbo",    # 还支持 gpt-3.5-turbo-16k, gpt-4, wenxin
+    "model": "gpt-3.5-turbo",  # 还支持 gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
     "use_azure_chatgpt": False,  # 是否使用azure的chatgpt
     "azure_deployment_id": "",  # azure 模型部署名称
     "azure_api_version": "",  # azure api版本
@@ -52,9 +52,13 @@ available_setting = {
     "request_timeout": 60,  # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
     "timeout": 120,  # chatgpt重试超时时间,在这个时间内,将会自动重试
     # Baidu 文心一言参数
-    "baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型
-    "baidu_wenxin_api_key": "", # Baidu api key
-    "baidu_wenxin_secret_key": "", # Baidu secret key
+    "baidu_wenxin_model": "eb-instant",  # 默认使用ERNIE-Bot-turbo模型
+    "baidu_wenxin_api_key": "",  # Baidu api key
+    "baidu_wenxin_secret_key": "",  # Baidu secret key
+    # 讯飞星火API
+    "xunfei_app_id": "",  # 讯飞应用ID
+    "xunfei_api_key": "",  # 讯飞 API key
+    "xunfei_api_secret": "",  # 讯飞 API secret
     # 语音设置
     "speech_recognition": False,  # 是否开启语音识别
     "group_speech_recognition": False,  # 是否开启群组语音识别

+ 3 - 2
plugins/godcmd/godcmd.py

@@ -294,7 +294,7 @@ class Godcmd(Plugin):
                     except Exception as e:
                         ok, result = False, "你没有设置私有GPT模型"
                 elif cmd == "reset":
-                    if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
+                    if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI]:
                         bot.sessions.clear_session(session_id)
                         channel.cancel_session(session_id)
                         ok, result = True, "会话已重置"
@@ -317,7 +317,8 @@ class Godcmd(Plugin):
                             load_config()
                             ok, result = True, "配置已重载"
                         elif cmd == "resetall":
-                            if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
+                            if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI,
+                                           const.BAIDU, const.XUNFEI]:
                                 channel.cancel_all_session()
                                 bot.sessions.clear_all_session()
                                 ok, result = True, "重置所有会话成功"