Explorar el Código

feat: add support for azure dalle

lanvent hace 2 años
padre
commit
3314b05648
Se han modificado 2 ficheros con 27 adiciones y 1 borrados
  1. 25 0
      bot/chatgpt/chat_gpt_bot.py
  2. 2 1
      bot/openai/open_ai_image.py

+ 25 - 0
bot/chatgpt/chat_gpt_bot.py

@@ -4,6 +4,7 @@ import time
 
 import openai
 import openai.error
+import requests
 
 from bot.bot import Bot
 from bot.chatgpt.chat_gpt_session import ChatGPTSession
@@ -155,3 +156,27 @@ class AzureChatGPTBot(ChatGPTBot):
         openai.api_type = "azure"
         openai.api_version = "2023-03-15-preview"
         self.args["deployment_id"] = conf().get("azure_deployment_id")
+
+    def create_img(self, query, retry_count=0, api_key=None):
+        api_base = "https://a-wxf.openai.azure.com/"
+        api_version = "2022-08-03-preview"
+        url = "{}dalle/text-to-image?api-version={}".format(api_base, api_version)
+        api_key = api_key or openai.api_key
+        headers = {"api-key": api_key, "Content-Type": "application/json"}
+        try:
+            body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")}
+            submission = requests.post(url, headers=headers, json=body)
+            operation_location = submission.headers["Operation-Location"]
+            retry_after = submission.headers["Retry-after"]
+            status = ""
+            image_url = ""
+            while status != "Succeeded":
+                logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds")
+                time.sleep(int(retry_after))
+                response = requests.get(operation_location, headers=headers)
+                status = response.json()["status"]
+                image_url = response.json()["result"]["contentUrl"]
+            return True, image_url
+        except Exception as e:
+            logger.error("create image error: {}".format(e))
+            return False, "图片生成失败"

+ 2 - 1
bot/openai/open_ai_image.py

@@ -15,12 +15,13 @@ class OpenAIImage(object):
         if conf().get("rate_limit_dalle"):
             self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
 
-    def create_img(self, query, retry_count=0):
+    def create_img(self, query, retry_count=0, api_key=None):
         try:
             if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
                 return False, "请求太快了,请休息一下再问我吧"
             logger.info("[OPEN_AI] image_query={}".format(query))
             response = openai.Image.create(
+                api_key=api_key,
                 prompt=query,  # 图片描述
                 n=1,  # 每次生成图片的数量
                 size=conf().get("image_create_size", "256x256"),  # 图片大小,可选有 256x256, 512x512, 1024x1024