|
|
@@ -3,12 +3,8 @@
|
|
|
import time
|
|
|
|
|
|
import openai
|
|
|
-from openai import AzureOpenAI
|
|
|
-
|
|
|
-client = AzureOpenAI(api_key=conf().get("open_ai_api_key"),
|
|
|
-api_version=conf().get("azure_api_version", "2023-06-01-preview"))
|
|
|
-import openai.error
|
|
|
import requests
|
|
|
+from openai import OpenAI
|
|
|
|
|
|
from bot.bot import Bot
|
|
|
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
|
|
@@ -20,6 +16,8 @@ from common.log import logger
|
|
|
from common.token_bucket import TokenBucket
|
|
|
from config import conf, load_config
|
|
|
|
|
|
+client = OpenAI(api_key=conf().get("open_ai_api_key"),
|
|
|
+ base_url=conf().get("open_ai_api_base", "https://api.openai.com/v1"))
|
|
|
|
|
|
# OpenAI对话模型API (可用)
|
|
|
class ChatGPTBot(Bot, OpenAIImage):
|
|
|
@@ -28,11 +26,11 @@ class ChatGPTBot(Bot, OpenAIImage):
|
|
|
# set the default api_key
|
|
|
if conf().get("open_ai_api_base"):
|
|
|
# TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(base_url=conf().get("open_ai_api_base"))'
|
|
|
- # openai.api_base = conf().get("open_ai_api_base")
|
|
|
+ openai.base_url = conf().get("open_ai_api_base")
|
|
|
proxy = conf().get("proxy")
|
|
|
if proxy:
|
|
|
# TODO: The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy=proxy)'
|
|
|
- # openai.proxy = proxy
|
|
|
+ openai.proxy = proxy
|
|
|
if conf().get("rate_limit_chatgpt"):
|
|
|
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
|
|
|
|
|
|
@@ -125,7 +123,7 @@ class ChatGPTBot(Bot, OpenAIImage):
|
|
|
# if api_key == None, the default openai.api_key will be used
|
|
|
if args is None:
|
|
|
args = self.args
|
|
|
- response = client.chat.completions.create(api_key=api_key, messages=session.messages, **args)
|
|
|
+ response = client.chat.completions.create(messages=session.messages, **args)
|
|
|
# logger.debug("[CHATGPT] response={}".format(response))
|
|
|
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
|
|
return {
|
|
|
@@ -173,9 +171,10 @@ class AzureChatGPTBot(ChatGPTBot):
|
|
|
super().__init__()
|
|
|
self.args["deployment_id"] = conf().get("azure_deployment_id")
|
|
|
|
|
|
- def create_img(self, query, retry_count=0, api_key=None):
|
|
|
+ def create_img(self, query, retry_count=0, api_key=None, **kwargs):
|
|
|
+ # TODO 升级一下这个version
|
|
|
api_version = "2022-08-03-preview"
|
|
|
- url = "{}dalle/text-to-image?api-version={}".format(openai.api_base, api_version)
|
|
|
+ url = "{}dalle/text-to-image?api-version={}".format(openai.base_url, api_version)
|
|
|
api_key = api_key or openai.api_key
|
|
|
headers = {"api-key": api_key, "Content-Type": "application/json"}
|
|
|
try:
|