|
|
@@ -3,6 +3,10 @@
|
|
|
import time
|
|
|
|
|
|
import openai
|
|
|
+from openai import AzureOpenAI
|
|
|
+
|
|
|
+client = AzureOpenAI(api_key=conf().get("open_ai_api_key"),
|
|
|
+api_version=conf().get("azure_api_version", "2023-06-01-preview"))
|
|
|
import openai.error
|
|
|
import requests
|
|
|
|
|
|
@@ -22,12 +26,13 @@ class ChatGPTBot(Bot, OpenAIImage):
|
|
|
def __init__(self):
|
|
|
super().__init__()
|
|
|
# set the default api_key
|
|
|
- openai.api_key = conf().get("open_ai_api_key")
|
|
|
if conf().get("open_ai_api_base"):
|
|
|
- openai.api_base = conf().get("open_ai_api_base")
|
|
|
+ # TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(base_url=conf().get("open_ai_api_base"))'
|
|
|
+ # openai.api_base = conf().get("open_ai_api_base")
|
|
|
proxy = conf().get("proxy")
|
|
|
if proxy:
|
|
|
- openai.proxy = proxy
|
|
|
+ # TODO: The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy=proxy)'
|
|
|
+ # openai.proxy = proxy
|
|
|
if conf().get("rate_limit_chatgpt"):
|
|
|
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
|
|
|
|
|
|
@@ -116,37 +121,37 @@ class ChatGPTBot(Bot, OpenAIImage):
|
|
|
"""
|
|
|
try:
|
|
|
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
|
|
|
- raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
|
|
+ raise openai.RateLimitError("RateLimitError: rate limit exceeded")
|
|
|
# if api_key == None, the default openai.api_key will be used
|
|
|
if args is None:
|
|
|
args = self.args
|
|
|
- response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
|
|
|
+ response = client.chat.completions.create(api_key=api_key, messages=session.messages, **args)
|
|
|
# logger.debug("[CHATGPT] response={}".format(response))
|
|
|
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
|
|
return {
|
|
|
- "total_tokens": response["usage"]["total_tokens"],
|
|
|
- "completion_tokens": response["usage"]["completion_tokens"],
|
|
|
- "content": response.choices[0]["message"]["content"],
|
|
|
+ "total_tokens": response.usage.total_tokens,
|
|
|
+ "completion_tokens": response.usage.completion_tokens,
|
|
|
+ "content": response.choices[0].message.content,
|
|
|
}
|
|
|
except Exception as e:
|
|
|
need_retry = retry_count < 2
|
|
|
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
|
|
- if isinstance(e, openai.error.RateLimitError):
|
|
|
+ if isinstance(e, openai.RateLimitError):
|
|
|
logger.warn("[CHATGPT] RateLimitError: {}".format(e))
|
|
|
result["content"] = "提问太快啦,请休息一下再问我吧"
|
|
|
if need_retry:
|
|
|
time.sleep(20)
|
|
|
- elif isinstance(e, openai.error.Timeout):
|
|
|
+ elif isinstance(e, openai.Timeout):
|
|
|
logger.warn("[CHATGPT] Timeout: {}".format(e))
|
|
|
result["content"] = "我没有收到你的消息"
|
|
|
if need_retry:
|
|
|
time.sleep(5)
|
|
|
- elif isinstance(e, openai.error.APIError):
|
|
|
+ elif isinstance(e, openai.APIError):
|
|
|
logger.warn("[CHATGPT] Bad Gateway: {}".format(e))
|
|
|
result["content"] = "请再问我一次"
|
|
|
if need_retry:
|
|
|
time.sleep(10)
|
|
|
- elif isinstance(e, openai.error.APIConnectionError):
|
|
|
+ elif isinstance(e, openai.APIConnectionError):
|
|
|
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
|
|
|
result["content"] = "我连接不到你的网络"
|
|
|
if need_retry:
|
|
|
@@ -166,8 +171,6 @@ class ChatGPTBot(Bot, OpenAIImage):
|
|
|
class AzureChatGPTBot(ChatGPTBot):
|
|
|
def __init__(self):
|
|
|
super().__init__()
|
|
|
- openai.api_type = "azure"
|
|
|
- openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
|
|
|
self.args["deployment_id"] = conf().get("azure_deployment_id")
|
|
|
|
|
|
def create_img(self, query, retry_count=0, api_key=None):
|
|
|
@@ -186,8 +189,8 @@ class AzureChatGPTBot(ChatGPTBot):
|
|
|
logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds")
|
|
|
time.sleep(int(retry_after))
|
|
|
response = requests.get(operation_location, headers=headers)
|
|
|
- status = response.json()["status"]
|
|
|
- image_url = response.json()["result"]["contentUrl"]
|
|
|
+ status = response.json().status
|
|
|
+ image_url = response.json().result.contentUrl
|
|
|
return True, image_url
|
|
|
except Exception as e:
|
|
|
logger.error("create image error: {}".format(e))
|