소스 검색

Merge pull request #1488 from yy1781051483/master

add xunfei v3.0
zhayujie 2 년 전
부모
커밋
061d8a3a5f
1개의 변경된 파일34개의 추가작업 그리고 20개의 파일을 삭제
  1. 34 20
      bot/xunfei/xunfei_spark_bot.py

+ 34 - 20
bot/xunfei/xunfei_spark_bot.py

@@ -40,10 +40,11 @@ class XunFeiBot(Bot):
         self.app_id = conf().get("xunfei_app_id")
         self.api_key = conf().get("xunfei_api_key")
         self.api_secret = conf().get("xunfei_api_secret")
-        # 默认使用v2.0版本,1.5版本可设置为 general
+        # 默认使用v3.0版本,2.0版本可设置为generalv2  1.5版本可设置为 general
         self.domain = "generalv2"
-        # 默认使用v2.0版本,1.5版本可设置为 "ws://spark-api.xf-yun.com/v1.1/chat"
-        self.spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"
+        # 默认使用v3.0版本,1.5版本可设置为 "ws://spark-api.xf-yun.com/v1.1/chat",
+        # 2.0版本可设置为 "ws://spark-api.xf-yun.com/v2.1/chat"
+        self.spark_url = "ws://spark-api.xf-yun.com/v3.1/chat"
         self.host = urlparse(self.spark_url).netloc
         self.path = urlparse(self.spark_url).path
         # 和wenxin使用相同的session机制
@@ -56,7 +57,8 @@ class XunFeiBot(Bot):
             request_id = self.gen_request_id(session_id)
             reply_map[request_id] = ""
             session = self.sessions.session_query(query, session_id)
-            threading.Thread(target=self.create_web_socket, args=(session.messages, request_id)).start()
+            threading.Thread(target=self.create_web_socket,
+                             args=(session.messages, request_id)).start()
             depth = 0
             time.sleep(0.1)
             t1 = time.time()
@@ -83,20 +85,27 @@ class XunFeiBot(Bot):
                     depth += 1
                     continue
             t2 = time.time()
-            logger.info(f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}")
-            self.sessions.session_reply(reply_map[request_id], session_id, usage.get("total_tokens"))
+            logger.info(
+                f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}"
+            )
+            self.sessions.session_reply(reply_map[request_id], session_id,
+                                        usage.get("total_tokens"))
             reply = Reply(ReplyType.TEXT, reply_map[request_id])
             del reply_map[request_id]
             return reply
         else:
-            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+            reply = Reply(ReplyType.ERROR,
+                          "Bot不支持处理{}类型的消息".format(context.type))
             return reply
 
     def create_web_socket(self, prompt, session_id, temperature=0.5):
         logger.info(f"[XunFei] start connect, prompt={prompt}")
         websocket.enableTrace(False)
         wsUrl = self.create_url()
-        ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close,
+        ws = websocket.WebSocketApp(wsUrl,
+                                    on_message=on_message,
+                                    on_error=on_error,
+                                    on_close=on_close,
                                     on_open=on_open)
         data_queue = queue.Queue(1000)
         queue_map[session_id] = data_queue
@@ -108,7 +117,8 @@ class XunFeiBot(Bot):
         ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
 
     def gen_request_id(self, session_id: str):
-        return session_id + "_" + str(int(time.time())) + "" + str(random.randint(0, 100))
+        return session_id + "_" + str(int(time.time())) + "" + str(
+            random.randint(0, 100))
 
     # 生成url
     def create_url(self):
@@ -122,22 +132,21 @@ class XunFeiBot(Bot):
         signature_origin += "GET " + self.path + " HTTP/1.1"
 
         # 进行hmac-sha256进行加密
-        signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
+        signature_sha = hmac.new(self.api_secret.encode('utf-8'),
+                                 signature_origin.encode('utf-8'),
                                  digestmod=hashlib.sha256).digest()
 
-        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
+        signature_sha_base64 = base64.b64encode(signature_sha).decode(
+            encoding='utf-8')
 
         authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
                                f'signature="{signature_sha_base64}"'
 
-        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
+        authorization = base64.b64encode(
+            authorization_origin.encode('utf-8')).decode(encoding='utf-8')
 
         # 将请求的鉴权参数组合为字典
-        v = {
-            "authorization": authorization,
-            "date": date,
-            "host": self.host
-        }
+        v = {"authorization": authorization, "date": date, "host": self.host}
         # 拼接鉴权参数,生成url
         url = self.spark_url + '?' + urlencode(v)
         # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
@@ -190,11 +199,15 @@ def on_close(ws, one, two):
 # 收到websocket连接建立的处理
 def on_open(ws):
     logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
-    thread.start_new_thread(run, (ws,))
+    thread.start_new_thread(run, (ws, ))
 
 
 def run(ws, *args):
-    data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question, temperature=ws.temperature))
+    data = json.dumps(
+        gen_params(appid=ws.appid,
+                   domain=ws.domain,
+                   question=ws.question,
+                   temperature=ws.temperature))
     ws.send(data)
 
 
@@ -212,7 +225,8 @@ def on_message(ws, message):
         content = choices["text"][0]["content"]
         data_queue = queue_map.get(ws.session_id)
         if not data_queue:
-            logger.error(f"[XunFei] can't find data queue, session_id={ws.session_id}")
+            logger.error(
+                f"[XunFei] can't find data queue, session_id={ws.session_id}")
             return
         reply_item = ReplyItem(content)
         if status == 2: