Просмотр исходного кода

fix: ensure get access_token thread-safe

lanvent 2 лет назад
Родитель
Сommit
c6601aaeed

+ 1 - 1
channel/wechatcom/README.md

@@ -54,4 +54,4 @@
 
 AIGC开放社区中已经部署了多个可免费使用的Bot,扫描下方的二维码会自动邀请你来体验。
 
-<img width="360" src="./docs/images/aigcopen.png">
+<img width="200" src="../../docs/images/aigcopen.png">

+ 3 - 3
channel/wechatcom/wechatcomapp_channel.py

@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 # -*- coding=utf-8 -*-
 import io
 import os
@@ -6,7 +5,7 @@ import textwrap
 
 import requests
 import web
-from wechatpy.enterprise import WeChatClient, create_reply, parse_message
+from wechatpy.enterprise import create_reply, parse_message
 from wechatpy.enterprise.crypto import WeChatCrypto
 from wechatpy.enterprise.exceptions import InvalidCorpIdException
 from wechatpy.exceptions import InvalidSignatureException, WeChatClientException
@@ -14,6 +13,7 @@ from wechatpy.exceptions import InvalidSignatureException, WeChatClientException
 from bridge.context import Context
 from bridge.reply import Reply, ReplyType
 from channel.chat_channel import ChatChannel
+from channel.wechatcom.wechatcomapp_client import WechatComAppClient
 from channel.wechatcom.wechatcomapp_message import WechatComAppMessage
 from common.log import logger
 from common.singleton import singleton
@@ -38,7 +38,7 @@ class WechatComAppChannel(ChatChannel):
             "[wechatcom] init: corp_id: {}, secret: {}, agent_id: {}, token: {}, aes_key: {}".format(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
         )
         self.crypto = WeChatCrypto(self.token, self.aes_key, self.corp_id)
-        self.client = WeChatClient(self.corp_id, self.secret)  # todo: 这里可能有线程安全问题
+        self.client = WechatComAppClient(self.corp_id, self.secret)
 
     def startup(self):
         # start message listener

+ 21 - 0
channel/wechatcom/wechatcomapp_client.py

@@ -0,0 +1,21 @@
+import threading
+import time
+
+from wechatpy.enterprise import WeChatClient
+
+
+class WechatComAppClient(WeChatClient):
+    def __init__(self, corp_id, secret, access_token=None, session=None, timeout=None, auto_retry=True):
+        super(WechatComAppClient, self).__init__(corp_id, secret, access_token, session, timeout, auto_retry)
+        self.fetch_access_token_lock = threading.Lock()
+
+    def fetch_access_token(self):  # 重载父类方法,加锁避免多线程重复获取access_token
+        with self.fetch_access_token_lock:
+            access_token = self.session.get(self.access_token_key)
+            if access_token:
+                if not self.expires_at:
+                    return access_token
+                timestamp = time.time()
+                if self.expires_at - timestamp > 60:
+                    return access_token
+            return super().fetch_access_token()

+ 4 - 17
channel/wechatcom/wechatcomapp_message.py

@@ -1,14 +1,9 @@
-import re
-
-import requests
 from wechatpy.enterprise import WeChatClient
 
 from bridge.context import ContextType
 from channel.chat_message import ChatMessage
 from common.log import logger
 from common.tmp_dir import TmpDir
-from lib import itchat
-from lib.itchat.content import *
 
 
 class WechatComAppMessage(ChatMessage):
@@ -23,9 +18,7 @@ class WechatComAppMessage(ChatMessage):
             self.content = msg.content
         elif msg.type == "voice":
             self.ctype = ContextType.VOICE
-            self.content = (
-                TmpDir().path() + msg.media_id + "." + msg.format
-            )  # content直接存临时目录路径
+            self.content = TmpDir().path() + msg.media_id + "." + msg.format  # content直接存临时目录路径
 
             def download_voice():
                 # 如果响应状态码是200,则将响应内容写入本地文件
@@ -34,9 +27,7 @@ class WechatComAppMessage(ChatMessage):
                     with open(self.content, "wb") as f:
                         f.write(response.content)
                 else:
-                    logger.info(
-                        f"[wechatcom] Failed to download voice file, {response.content}"
-                    )
+                    logger.info(f"[wechatcom] Failed to download voice file, {response.content}")
 
             self._prepare_fn = download_voice
         elif msg.type == "image":
@@ -50,15 +41,11 @@ class WechatComAppMessage(ChatMessage):
                     with open(self.content, "wb") as f:
                         f.write(response.content)
                 else:
-                    logger.info(
-                        f"[wechatcom] Failed to download image file, {response.content}"
-                    )
+                    logger.info(f"[wechatcom] Failed to download image file, {response.content}")
 
             self._prepare_fn = download_image
         else:
-            raise NotImplementedError(
-                "Unsupported message type: Type:{} ".format(msg.type)
-            )
+            raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
 
         self.from_user_id = msg.source
         self.to_user_id = msg.target