|
@@ -12,16 +12,18 @@ from config import conf
|
|
|
class OpenAIImage(object):
|
|
class OpenAIImage(object):
|
|
|
def __init__(self):
|
|
def __init__(self):
|
|
|
openai.api_key = conf().get("open_ai_api_key")
|
|
openai.api_key = conf().get("open_ai_api_key")
|
|
|
|
|
+ openai.api_base = conf().get("open_ai_api_base")
|
|
|
if conf().get("rate_limit_dalle"):
|
|
if conf().get("rate_limit_dalle"):
|
|
|
self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
|
|
self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
|
|
|
|
|
|
|
|
- def create_img(self, query, retry_count=0, api_key=None):
|
|
|
|
|
|
|
+ def create_img(self, query, retry_count=0, api_key=None, api_base=None):
|
|
|
try:
|
|
try:
|
|
|
if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
|
|
if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
|
|
|
return False, "请求太快了,请休息一下再问我吧"
|
|
return False, "请求太快了,请休息一下再问我吧"
|
|
|
logger.info("[OPEN_AI] image_query={}".format(query))
|
|
logger.info("[OPEN_AI] image_query={}".format(query))
|
|
|
response = openai.Image.create(
|
|
response = openai.Image.create(
|
|
|
api_key=api_key,
|
|
api_key=api_key,
|
|
|
|
|
+ api_base=api_base,
|
|
|
prompt=query, # 图片描述
|
|
prompt=query, # 图片描述
|
|
|
n=1, # 每次生成图片的数量
|
|
n=1, # 每次生成图片的数量
|
|
|
model=conf().get("text_to_image") or "dall-e-2",
|
|
model=conf().get("text_to_image") or "dall-e-2",
|