|
|
@@ -3,6 +3,7 @@
|
|
|
from bot.bot import Bot
|
|
|
from config import conf, load_config
|
|
|
from common.log import logger
|
|
|
+from common.token_bucket import TokenBucket
|
|
|
from common.expired_dict import ExpiredDict
|
|
|
import openai
|
|
|
import time
|
|
|
@@ -21,6 +22,10 @@ class ChatGPTBot(Bot):
|
|
|
proxy = conf().get('proxy')
|
|
|
if proxy:
|
|
|
openai.proxy = proxy
|
|
|
+ if conf().get('rate_limit_chatgpt'):
|
|
|
+ self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
|
|
|
+ if conf().get('rate_limit_dalle'):
|
|
|
+ self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
|
|
|
|
|
|
def reply(self, query, context=None):
|
|
|
# acquire reply content
|
|
|
@@ -63,6 +68,8 @@ class ChatGPTBot(Bot):
|
|
|
:return: {}
|
|
|
'''
|
|
|
try:
|
|
|
+ if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
|
|
|
+ return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
|
|
response = openai.ChatCompletion.create(
|
|
|
model= conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
|
|
messages=session,
|
|
|
@@ -102,6 +109,8 @@ class ChatGPTBot(Bot):
|
|
|
|
|
|
def create_img(self, query, retry_count=0):
|
|
|
try:
|
|
|
+ if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
|
|
|
+ return "请求太快了,请休息一下再问我吧"
|
|
|
logger.info("[OPEN_AI] image_query={}".format(query))
|
|
|
response = openai.Image.create(
|
|
|
prompt=query, #图片描述
|
|
|
@@ -118,7 +127,7 @@ class ChatGPTBot(Bot):
|
|
|
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
|
|
return self.create_img(query, retry_count+1)
|
|
|
else:
|
|
|
- return "提问太快啦,请休息一下再问我吧"
|
|
|
+ return "请求太快啦,请休息一下再问我吧"
|
|
|
except Exception as e:
|
|
|
logger.exception(e)
|
|
|
return None
|