|
@@ -30,6 +30,9 @@ import random
|
|
|
# 消息队列 map
|
|
# 消息队列 map
|
|
|
queue_map = dict()
|
|
queue_map = dict()
|
|
|
|
|
|
|
|
|
|
+# 响应队列 map
|
|
|
|
|
+reply_map = dict()
|
|
|
|
|
+
|
|
|
|
|
|
|
|
class XunFeiBot(Bot):
|
|
class XunFeiBot(Bot):
|
|
|
def __init__(self):
|
|
def __init__(self):
|
|
@@ -43,7 +46,6 @@ class XunFeiBot(Bot):
|
|
|
self.spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"
|
|
self.spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"
|
|
|
self.host = urlparse(self.spark_url).netloc
|
|
self.host = urlparse(self.spark_url).netloc
|
|
|
self.path = urlparse(self.spark_url).path
|
|
self.path = urlparse(self.spark_url).path
|
|
|
- self.answer = ""
|
|
|
|
|
# 和wenxin使用相同的session机制
|
|
# 和wenxin使用相同的session机制
|
|
|
self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
|
|
self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
|
|
|
|
|
|
|
@@ -52,6 +54,7 @@ class XunFeiBot(Bot):
|
|
|
logger.info("[XunFei] query={}".format(query))
|
|
logger.info("[XunFei] query={}".format(query))
|
|
|
session_id = context["session_id"]
|
|
session_id = context["session_id"]
|
|
|
request_id = self.gen_request_id(session_id)
|
|
request_id = self.gen_request_id(session_id)
|
|
|
|
|
+ reply_map[request_id] = ""
|
|
|
session = self.sessions.session_query(query, session_id)
|
|
session = self.sessions.session_query(query, session_id)
|
|
|
threading.Thread(target=self.create_web_socket, args=(session.messages, request_id)).start()
|
|
threading.Thread(target=self.create_web_socket, args=(session.messages, request_id)).start()
|
|
|
depth = 0
|
|
depth = 0
|
|
@@ -70,19 +73,20 @@ class XunFeiBot(Bot):
|
|
|
# 请求结束
|
|
# 请求结束
|
|
|
del queue_map[request_id]
|
|
del queue_map[request_id]
|
|
|
if data_item.reply:
|
|
if data_item.reply:
|
|
|
- self.answer += data_item.reply
|
|
|
|
|
|
|
+ reply_map[request_id] += data_item.reply
|
|
|
usage = data_item.usage
|
|
usage = data_item.usage
|
|
|
break
|
|
break
|
|
|
|
|
|
|
|
- self.answer += data_item.reply
|
|
|
|
|
|
|
+ reply_map[request_id] += data_item.reply
|
|
|
depth += 1
|
|
depth += 1
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
depth += 1
|
|
depth += 1
|
|
|
continue
|
|
continue
|
|
|
t2 = time.time()
|
|
t2 = time.time()
|
|
|
- logger.info(f"[XunFei-API] response={self.answer}, time={t2 - t1}s, usage={usage}")
|
|
|
|
|
- self.sessions.session_reply(self.answer, session_id, usage.get("total_tokens"))
|
|
|
|
|
- reply = Reply(ReplyType.TEXT, self.answer)
|
|
|
|
|
|
|
+ logger.info(f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}")
|
|
|
|
|
+ self.sessions.session_reply(reply_map[request_id], session_id, usage.get("total_tokens"))
|
|
|
|
|
+ reply = Reply(ReplyType.TEXT, reply_map[request_id])
|
|
|
|
|
+ del reply_map[request_id]
|
|
|
return reply
|
|
return reply
|
|
|
else:
|
|
else:
|
|
|
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
|
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|