Răsfoiți Sursa

fix: wechatmp's deadloop when reply is None

JS00000 3 ani în urmă
părinte
comite
8ee7a48151

+ 5 - 0
channel/chat_channel.py

@@ -233,6 +233,9 @@ class ChatChannel(Channel):
                 time.sleep(3+3*retry_cnt)
                 self._send(reply, context, retry_cnt+1)
 
+    def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
+        pass
+
     def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
         logger.exception("Worker return exception: {}".format(exception))
 
@@ -242,6 +245,8 @@ class ChatChannel(Channel):
                 worker_exception = worker.exception()
                 if worker_exception:
                     self._fail_callback(session_id, exception = worker_exception, **kwargs)
+                else:
+                    self._success_callback(session_id, **kwargs)
             except CancelledError as e:
                 logger.info("Worker cancelled, session_id = {}".format(session_id))
             except Exception as e:

+ 5 - 4
channel/wechatmp/ServiceAccount.py

@@ -16,7 +16,7 @@ class Query():
 
     def POST(self):
         # Make sure to return the instance that first created, @singleton will do that. 
-        channel_instance = WechatMPChannel()
+        channel = WechatMPChannel()
         try:
             webData = web.data()
             # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
@@ -27,14 +27,15 @@ class Query():
                 message_id = wechatmp_msg.msg_id
 
                 logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
-                context = channel_instance._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
+                context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
                 if context:
                     # set private openai_api_key
                     # if from_user is not changed in itchat, this can be placed at chat_channel
                     user_data = conf().get_user_data(from_user)
                     context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
-                    channel_instance.produce(context)
-                # The reply will be sent by channel_instance.send() in another thread
+                    channel.produce(context)
+                    channel.running.add(from_user)
+                # The reply will be sent by channel.send() in another thread
                 return "success"
 
             elif wechatmp_msg.msg_type == 'event':

+ 18 - 11
channel/wechatmp/SubscribeAccount.py

@@ -41,7 +41,8 @@ class Query():
                     context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
                     logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg))
                     if message_id in channel.received_msgs: # received and finished
-                        return 
+                        # no return because of bandwords or other reasons
+                        return "success"
                     if supported and context:
                         # set private openai_api_key
                         # if from_user is not changed in itchat, this can be placed at chat_channel
@@ -71,11 +72,12 @@ class Query():
                     channel.query1[cache_key] = False
                     channel.query2[cache_key] = False
                     channel.query3[cache_key] = False
-                # Request again
+                # User request again, and the answer is not ready
                 elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True:
                     channel.query1[cache_key] = False  #To improve waiting experience, this can be set to True.
                     channel.query2[cache_key] = False  #To improve waiting experience, this can be set to True.
                     channel.query3[cache_key] = False
+                # User request again, and the answer is ready
                 elif cache_key in channel.cache_dict:
                     # Skip the waiting phase
                     channel.query1[cache_key] = True
@@ -89,7 +91,7 @@ class Query():
                     logger.debug("[wechatmp] query1 {}".format(cache_key))
                     channel.query1[cache_key] = True
                     cnt = 0
-                    while cache_key not in channel.cache_dict and cnt < 45:
+                    while cache_key in channel.running and cnt < 45:
                         cnt = cnt + 1
                         time.sleep(0.1)
                     if cnt == 45:
@@ -104,7 +106,7 @@ class Query():
                     logger.debug("[wechatmp] query2 {}".format(cache_key))
                     channel.query2[cache_key] = True
                     cnt = 0
-                    while cache_key not in channel.cache_dict and cnt < 45:
+                    while cache_key in channel.running and cnt < 45:
                         cnt = cnt + 1
                         time.sleep(0.1)
                     if cnt == 45:
@@ -119,7 +121,7 @@ class Query():
                     logger.debug("[wechatmp] query3 {}".format(cache_key))
                     channel.query3[cache_key] = True
                     cnt = 0
-                    while cache_key not in channel.cache_dict and cnt < 40:
+                    while cache_key in channel.running and cnt < 40:
                         cnt = cnt + 1
                         time.sleep(0.1)
                     if cnt == 40:
@@ -132,12 +134,17 @@ class Query():
                     else:
                         pass
 
-                if float(time.time()) - float(query_time) > 4.8:
-                    reply_text = "【正在思考中,回复任意文字尝试获取回复】"
-                    logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id))
-                    replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
-                    return replyPost
-                
+
+                if cache_key not in channel.cache_dict and cache_key not in channel.running:
+                    # no return because of bandwords or other reasons
+                    return "success"
+
+                # if float(time.time()) - float(query_time) > 4.8:
+                #     reply_text = "【正在思考中,回复任意文字尝试获取回复】"
+                #     logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id))
+                #     replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
+                #     return replyPost
+
                 if cache_key in channel.cache_dict:
                     content = channel.cache_dict[cache_key]
                     if len(content.encode('utf8'))<=MAX_UTF8_LEN:

+ 6 - 6
channel/wechatmp/wechatmp_channel.py

@@ -97,8 +97,7 @@ class WechatMPChannel(ChatChannel):
         if self.passive_reply:
             receiver = context["receiver"]
             self.cache_dict[receiver] = reply.content
-            self.running.remove(receiver)
-            logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply))
+            logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply))
         else:
             receiver = context["receiver"]
             reply_text = reply.content
@@ -115,11 +114,12 @@ class WechatMPChannel(ChatChannel):
             logger.info("[send] Do send to {}: {}".format(receiver, reply_text))
         return
 
+    def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
+        self.running.remove(session_id)
 
-    def _fail_callback(self, session_id, exception, context, **kwargs):
+    def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
         logger.exception("[wechatmp] Fail to generation message to user, msgId={}, exception={}".format(context['msg'].msg_id, exception))
-        assert session_id not in self.cache_dict
+        if self.passive_reply:
+            assert session_id not in self.cache_dict
         self.running.remove(session_id)
 
-
-