فهرست منبع

refactor: reuse openai image interface

lanvent 3 سال پیش
والد
کامیت
721b36c7f7
3فایلهای تغییر یافته به همراه48 افزوده شده و 54 حذف شده
  1. 2 27
      bot/chatgpt/chat_gpt_bot.py
  2. 9 27
      bot/openai/open_ai_bot.py
  3. 37 0
      bot/openai/open_ai_image.py

+ 2 - 27
bot/chatgpt/chat_gpt_bot.py

@@ -1,6 +1,7 @@
 # encoding:utf-8
 # encoding:utf-8
 
 
 from bot.bot import Bot
 from bot.bot import Bot
+from bot.openai.open_ai_image import OpenAIImage
 from bridge.context import ContextType
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
 from bridge.reply import Reply, ReplyType
 from config import conf, load_config
 from config import conf, load_config
@@ -12,7 +13,7 @@ import time
 
 
 
 
 # OpenAI对话模型API (可用)
 # OpenAI对话模型API (可用)
-class ChatGPTBot(Bot):
+class ChatGPTBot(Bot,OpenAIImage):
     def __init__(self):
     def __init__(self):
         openai.api_key = conf().get('open_ai_api_key')
         openai.api_key = conf().get('open_ai_api_key')
         if conf().get('open_ai_api_base'):
         if conf().get('open_ai_api_base'):
@@ -23,8 +24,6 @@ class ChatGPTBot(Bot):
             openai.proxy = proxy
             openai.proxy = proxy
         if conf().get('rate_limit_chatgpt'):
         if conf().get('rate_limit_chatgpt'):
             self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
             self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
-        if conf().get('rate_limit_dalle'):
-            self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
 
 
     def reply(self, query, context=None):
     def reply(self, query, context=None):
         # acquire reply content
         # acquire reply content
@@ -128,30 +127,6 @@ class ChatGPTBot(Bot):
             self.sessions.clear_session(session_id)
             self.sessions.clear_session(session_id)
             return {"completion_tokens": 0, "content": "请再问我一次吧"}
             return {"completion_tokens": 0, "content": "请再问我一次吧"}
 
 
-    def create_img(self, query, retry_count=0):
-        try:
-            if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
-                return False, "请求太快了,请休息一下再问我吧"
-            logger.info("[OPEN_AI] image_query={}".format(query))
-            response = openai.Image.create(
-                prompt=query,    #图片描述
-                n=1,             #每次生成图片的数量
-                size="256x256"   #图片大小,可选有 256x256, 512x512, 1024x1024
-            )
-            image_url = response['data'][0]['url']
-            logger.info("[OPEN_AI] image_url={}".format(image_url))
-            return True, image_url
-        except openai.error.RateLimitError as e:
-            logger.warn(e)
-            if retry_count < 1:
-                time.sleep(5)
-                logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
-                return self.create_img(query, retry_count+1)
-            else:
-                return False, "提问太快啦,请休息一下再问我吧"
-        except Exception as e:
-            logger.exception(e)
-            return False, str(e)
 
 
 
 
 class AzureChatGPTBot(ChatGPTBot):
 class AzureChatGPTBot(ChatGPTBot):

+ 9 - 27
bot/openai/open_ai_bot.py

@@ -1,6 +1,7 @@
 # encoding:utf-8
 # encoding:utf-8
 
 
 from bot.bot import Bot
 from bot.bot import Bot
+from bot.openai.open_ai_image import OpenAIImage
 from bridge.context import ContextType
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
 from bridge.reply import Reply, ReplyType
 from config import conf
 from config import conf
@@ -11,7 +12,7 @@ import time
 user_session = dict()
 user_session = dict()
 
 
 # OpenAI对话模型API (可用)
 # OpenAI对话模型API (可用)
-class OpenAIBot(Bot):
+class OpenAIBot(Bot, OpenAIImage):
     def __init__(self):
     def __init__(self):
         openai.api_key = conf().get('open_ai_api_key')
         openai.api_key = conf().get('open_ai_api_key')
         if conf().get('open_ai_api_base'):
         if conf().get('open_ai_api_base'):
@@ -45,7 +46,13 @@ class OpenAIBot(Bot):
                     reply = Reply(ReplyType.TEXT, reply_content)
                     reply = Reply(ReplyType.TEXT, reply_content)
                 return reply
                 return reply
             elif context.type == ContextType.IMAGE_CREATE:
             elif context.type == ContextType.IMAGE_CREATE:
-                return self.create_img(query, 0)
+                ok, retstring = self.create_img(query, 0)
+                reply = None
+                if ok:
+                    reply = Reply(ReplyType.IMAGE_URL, retstring)
+                else:
+                    reply = Reply(ReplyType.ERROR, retstring)
+                return reply
 
 
     def reply_text(self, query, user_id, retry_count=0):
     def reply_text(self, query, user_id, retry_count=0):
         try:
         try:
@@ -77,31 +84,6 @@ class OpenAIBot(Bot):
             Session.clear_session(user_id)
             Session.clear_session(user_id)
             return "请再问我一次吧"
             return "请再问我一次吧"
 
 
-
-    def create_img(self, query, retry_count=0):
-        try:
-            logger.info("[OPEN_AI] image_query={}".format(query))
-            response = openai.Image.create(
-                prompt=query,    #图片描述
-                n=1,             #每次生成图片的数量
-                size="256x256"   #图片大小,可选有 256x256, 512x512, 1024x1024
-            )
-            image_url = response['data'][0]['url']
-            logger.info("[OPEN_AI] image_url={}".format(image_url))
-            return image_url
-        except openai.error.RateLimitError as e:
-            logger.warn(e)
-            if retry_count < 1:
-                time.sleep(5)
-                logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
-                return self.reply_text(query, retry_count+1)
-            else:
-                return "提问太快啦,请休息一下再问我吧"
-        except Exception as e:
-            logger.exception(e)
-            return None
-
-
 class Session(object):
 class Session(object):
     @staticmethod
     @staticmethod
     def build_session_query(query, user_id):
     def build_session_query(query, user_id):

+ 37 - 0
bot/openai/open_ai_image.py

@@ -0,0 +1,37 @@
+import time
+import openai
+from common.token_bucket import TokenBucket
+from common.log import logger
+from config import conf
+
+# OPENAI提供的画图接口
+class OpenAIImage(object):
+    def __init__(self):
+        openai.api_key = conf().get('open_ai_api_key')
+        if conf().get('rate_limit_dalle'):
+            self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
+            
+    def create_img(self, query, retry_count=0):
+        try:
+            if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
+                return False, "请求太快了,请休息一下再问我吧"
+            logger.info("[OPEN_AI] image_query={}".format(query))
+            response = openai.Image.create(
+                prompt=query,    #图片描述
+                n=1,             #每次生成图片的数量
+                size="256x256"   #图片大小,可选有 256x256, 512x512, 1024x1024
+            )
+            image_url = response['data'][0]['url']
+            logger.info("[OPEN_AI] image_url={}".format(image_url))
+            return True, image_url
+        except openai.error.RateLimitError as e:
+            logger.warn(e)
+            if retry_count < 1:
+                time.sleep(5)
+                logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
+                return self.create_img(query, retry_count+1)
+            else:
+                return False, "提问太快啦,请休息一下再问我吧"
+        except Exception as e:
+            logger.exception(e)
+            return False, str(e)