Explorar o código

wechatmp: logic simplification

JS00000 %!s(int64=3) %!d(string=hai) anos
pai
achega
df4c1f0401
Modificáronse 2 ficheiros con 82 adicións e 135 borrados
  1. 67 120
      channel/wechatmp/subscribe_account.py
  2. 15 15
      channel/wechatmp/wechatmp_channel.py

+ 67 - 120
channel/wechatmp/subscribe_account.py

@@ -17,12 +17,11 @@ class Query:
         return verify_server(web.input())
 
     def POST(self):
-        # Make sure to return the instance that first created, @singleton will do that.
-        channel = WechatMPChannel()
         try:
-            query_time = time.time()
+            request_time = time.time()
+            channel = WechatMPChannel()
             webData = web.data()
-            logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
+            logger.debug("[wechatmp] Receive post data:\n" + webData.decode("utf-8"))
             wechatmp_msg = receive.parse_xml(webData)
             if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice":
                 from_user = wechatmp_msg.from_user_id
@@ -30,48 +29,34 @@ class Query:
                 message = wechatmp_msg.content.decode("utf-8")
                 message_id = wechatmp_msg.msg_id
 
-                logger.info(
-                    "[wechatmp] {}:{} Receive post query {} {}: {}".format(
-                        web.ctx.env.get("REMOTE_ADDR"),
-                        web.ctx.env.get("REMOTE_PORT"),
-                        from_user,
-                        message_id,
-                        message,
-                    )
-                )
                 supported = True
                 if "【收到不支持的消息类型,暂无法显示】" in message:
                     supported = False  # not supported, used to refresh
-                cache_key = from_user
 
-                reply_text = ""
                 # New request
                 if (
-                    cache_key not in channel.cache_dict
-                    and cache_key not in channel.running
+                    from_user not in channel.cache_dict
+                    and from_user not in channel.running
+                    or message.startswith("#") 
+                    and message_id not in channel.request_cnt # insert the godcmd
                 ):
-                    # The first query begin, reset the cache
+                    # The first query begin
                     context = channel._compose_context(
                         ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg
                     )
                     logger.debug(
                         "[wechatmp] context: {} {}".format(context, wechatmp_msg)
                     )
-                    if message_id in channel.received_msgs:  # received and finished
-                        # no return because of bandwords or other reasons
-                        return "success"
+
                     if supported and context:
                         # set private openai_api_key
                         # if from_user is not changed in itchat, this can be placed at chat_channel
                         user_data = conf().get_user_data(from_user)
-                        context["openai_api_key"] = user_data.get(
-                            "openai_api_key"
-                        )  # None or user openai_api_key
-                        channel.received_msgs[message_id] = wechatmp_msg
-                        channel.running.add(cache_key)
+                        context["openai_api_key"] = user_data.get("openai_api_key")
+                        channel.running.add(from_user)
                         channel.produce(context)
                     else:
-                        trigger_prefix = conf().get("single_chat_prefix", [""])[0]
+                        trigger_prefix = conf().get("single_chat_prefix", [""])
                         if trigger_prefix or not supported:
                             if trigger_prefix:
                                 content = textwrap.dedent(
@@ -92,108 +77,67 @@ class Query:
                                 """\
                                 未知错误,请稍后再试"""
                             )
-                        replyMsg = reply.TextMsg(
-                            wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
-                        )
-                        return replyMsg.send()
-                    channel.query1[cache_key] = False
-                    channel.query2[cache_key] = False
-                    channel.query3[cache_key] = False
-                # User request again, and the answer is not ready
-                elif (
-                    cache_key in channel.running
-                    and channel.query1.get(cache_key) == True
-                    and channel.query2.get(cache_key) == True
-                    and channel.query3.get(cache_key) == True
-                ):
-                    channel.query1[
-                        cache_key
-                    ] = False  # To improve waiting experience, this can be set to True.
-                    channel.query2[
-                        cache_key
-                    ] = False  # To improve waiting experience, this can be set to True.
-                    channel.query3[cache_key] = False
-                # User request again, and the answer is ready
-                elif cache_key in channel.cache_dict:
-                    # Skip the waiting phase
-                    channel.query1[cache_key] = True
-                    channel.query2[cache_key] = True
-                    channel.query3[cache_key] = True
-
-                assert not (
-                    cache_key in channel.cache_dict and cache_key in channel.running
+                        replyPost = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content).send()
+                        return replyPost
+
+
+                # Wechat official server will request 3 times (5 seconds each), with the same message_id.
+                # Because the interval is 5 seconds, here assumed that do not have multithreading problems.
+                request_cnt = channel.request_cnt.get(message_id, 0) + 1
+                channel.request_cnt[message_id] = request_cnt
+                logger.info(
+                    "[wechatmp] Request {} from {} {}\n{}\n{}:{}".format(
+                        request_cnt,
+                        from_user,
+                        message_id,
+                        message,
+                        web.ctx.env.get("REMOTE_ADDR"),
+                        web.ctx.env.get("REMOTE_PORT"),
+                    )
                 )
 
-                if channel.query1.get(cache_key) == False:
-                    # The first query from wechat official server
-                    logger.debug("[wechatmp] query1 {}".format(cache_key))
-                    channel.query1[cache_key] = True
-                    cnt = 0
-                    while cache_key in channel.running and cnt < 45:
-                        cnt = cnt + 1
+                task_running = True
+                waiting_until = request_time + 4
+                while time.time() < waiting_until:
+                    if from_user in channel.running:
                         time.sleep(0.1)
-                    if cnt == 45:
-                        # waiting for timeout (the POST query will be closed by wechat official server)
-                        time.sleep(1)
-                        # and do nothing
-                        return
                     else:
-                        pass
-                elif channel.query2.get(cache_key) == False:
-                    # The second query from wechat official server
-                    logger.debug("[wechatmp] query2 {}".format(cache_key))
-                    channel.query2[cache_key] = True
-                    cnt = 0
-                    while cache_key in channel.running and cnt < 45:
-                        cnt = cnt + 1
-                        time.sleep(0.1)
-                    if cnt == 45:
-                        # waiting for timeout (the POST query will be closed by wechat official server)
-                        time.sleep(1)
-                        # and do nothing
-                        return
-                    else:
-                        pass
-                elif channel.query3.get(cache_key) == False:
-                    # The third query from wechat official server
-                    logger.debug("[wechatmp] query3 {}".format(cache_key))
-                    channel.query3[cache_key] = True
-                    cnt = 0
-                    while cache_key in channel.running and cnt < 40:
-                        cnt = cnt + 1
-                        time.sleep(0.1)
-                    if cnt == 40:
-                        # Have waiting for 3x5 seconds
+                        task_running = False
+                        break
+
+                reply_text = ""
+                if task_running:
+                    if request_cnt < 3:
+                        # waiting for timeout (the POST request will be closed by Wechat official server)
+                        time.sleep(2)
+                        # and do nothing, waiting for the next request
+                        return "success"
+                    else: # request_cnt == 3:
                         # return timeout message
                         reply_text = "【正在思考中,回复任意文字尝试获取回复】"
-                        logger.info(
-                            "[wechatmp] Three queries has finished For {}: {}".format(
-                                from_user, message_id
-                            )
-                        )
-                        replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
-                        return replyPost
-                    else:
-                        pass
+                        # replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
+                        # return replyPost
+
+                # reply or reply_text is ready
+                channel.request_cnt.pop(message_id)
 
+                # no return because of bandwords or other reasons
                 if (
-                    cache_key not in channel.cache_dict
-                    and cache_key not in channel.running
+                    from_user not in channel.cache_dict
+                    and from_user not in channel.running
                 ):
-                    # no return because of bandwords or other reasons
                     return "success"
 
-                # if float(time.time()) - float(query_time) > 4.8:
-                #     reply_text = "【正在思考中,回复任意文字尝试获取回复】"
-                #     logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id))
-                #     replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
-                #     return replyPost
+                # reply is ready
+                if from_user in channel.cache_dict:
+                    # Only one message thread can access to the cached data
+                    try:
+                        content = channel.cache_dict.pop(from_user)
+                    except KeyError:
+                        return "success"
 
-                if cache_key in channel.cache_dict:
-                    content = channel.cache_dict[cache_key]
                     if len(content.encode("utf8")) <= MAX_UTF8_LEN:
-                        reply_text = channel.cache_dict[cache_key]
-                        channel.cache_dict.pop(cache_key)
+                        reply_text = content
                     else:
                         continue_text = "\n【未完待续,回复任意文字以继续】"
                         splits = split_string_by_utf8_length(
@@ -202,11 +146,14 @@ class Query:
                             max_split=1,
                         )
                         reply_text = splits[0] + continue_text
-                        channel.cache_dict[cache_key] = splits[1]
+                        channel.cache_dict[from_user] = splits[1]
+
                 logger.info(
-                    "[wechatmp] {}:{} Do send {}".format(
-                        web.ctx.env.get("REMOTE_ADDR"),
-                        web.ctx.env.get("REMOTE_PORT"),
+                    "[wechatmp] Request {} do send to {} {}: {}\n{}".format(
+                        request_cnt,
+                        from_user,
+                        message_id,
+                        message,
                         reply_text,
                     )
                 )

+ 15 - 15
channel/wechatmp/wechatmp_channel.py

@@ -1,5 +1,4 @@
 # -*- coding: utf-8 -*-
-import web
 import io
 import imghdr
 import requests
@@ -8,18 +7,17 @@ from bridge.reply import *
 from channel.chat_channel import ChatChannel
 from channel.wechatmp.wechatmp_client import WechatMPClient
 from channel.wechatmp.common import *
-from common.expired_dict import ExpiredDict
 from common.log import logger
-from common.tmp_dir import TmpDir
 from common.singleton import singleton
 from config import conf
 
+import web
 # If using SSL, uncomment the following lines, and modify the certificate path.
-from cheroot.server import HTTPServer
-from cheroot.ssl.builtin import BuiltinSSLAdapter
-HTTPServer.ssl_adapter = BuiltinSSLAdapter(
-        certificate='/ssl/cert.pem',
-        private_key='/ssl/cert.key')
+# from cheroot.server import HTTPServer
+# from cheroot.ssl.builtin import BuiltinSSLAdapter
+# HTTPServer.ssl_adapter = BuiltinSSLAdapter(
+#         certificate='/ssl/cert.pem',
+#         private_key='/ssl/cert.key')
 
 
 @singleton
@@ -27,15 +25,17 @@ class WechatMPChannel(ChatChannel):
     def __init__(self, passive_reply=True):
         super().__init__()
         self.passive_reply = passive_reply
-        self.running = set()
-        self.received_msgs = ExpiredDict(60 * 60 * 24)
+        self.flag = 0
+
         self.client = WechatMPClient()
         if self.passive_reply:
             self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
+            # Cache the reply to the user's first message
             self.cache_dict = dict()
-            self.query1 = dict()
-            self.query2 = dict()
-            self.query3 = dict()
+            # Record whether the current message is being processed
+            self.running = set()
+            # Count the request from wechat official server by message_id
+            self.request_cnt = dict()
         else:
             self.NOT_SUPPORT_REPLYTYPE = []
 
@@ -53,8 +53,8 @@ class WechatMPChannel(ChatChannel):
     def send(self, reply: Reply, context: Context):
         receiver = context["receiver"]
         if self.passive_reply:
+            logger.info("[wechatmp] reply to {} cached:\n{}".format(receiver, reply))
             self.cache_dict[receiver] = reply.content
-            logger.info("[wechatmp] reply cached reply to {}: {}".format(receiver, reply))
         else:
             if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
                 reply_text = reply.content
@@ -64,7 +64,6 @@ class WechatMPChannel(ChatChannel):
             elif reply.type == ReplyType.VOICE:
                 voice_file_path = reply.content
                 logger.info("[wechatmp] voice file path {}".format(voice_file_path))
-
                 with open(voice_file_path, 'rb') as f:
                     filename = receiver + "-" + context["msg"].msg_id + ".mp3"
                     media_id = self.client.upload_media("voice", (filename, f, "audio/mpeg"))
@@ -86,6 +85,7 @@ class WechatMPChannel(ChatChannel):
                 media_id = self.client.upload_media("image", (filename, image_storage, content_type))
                 self.client.send_image(receiver, media_id)
                 logger.info("[wechatmp] sendImage url={}, receiver={}".format(img_url, receiver))
+
             elif reply.type == ReplyType.IMAGE:  # 从文件读取图片
                 image_storage = reply.content
                 image_storage.seek(0)