|
|
@@ -40,21 +40,21 @@ class ChatGPTBot(Bot):
|
|
|
# return self.reply_text_stream(query, new_query, from_user_id)
|
|
|
|
|
|
reply_content = self.reply_text(new_query, from_user_id, 0)
|
|
|
- logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
|
|
- if reply_content:
|
|
|
- Session.save_session(query, reply_content, from_user_id)
|
|
|
- return reply_content[1]
|
|
|
+ logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content["content"]))
|
|
|
+ if reply_content["completion_tokens"] > 0:
|
|
|
+ Session.save_session(reply_content["content"], from_user_id, reply_content["total_tokens"])
|
|
|
+ return reply_content["content"]
|
|
|
|
|
|
elif context.get('type', None) == 'IMAGE_CREATE':
|
|
|
return self.create_img(query, 0)
|
|
|
|
|
|
- def reply_text(self, query, user_id, retry_count=0):
|
|
|
+ def reply_text(self, query, user_id, retry_count=0) ->dict:
|
|
|
'''
|
|
|
call openai's ChatCompletion to get the answer
|
|
|
:param query: query content
|
|
|
:param user_id: from user id
|
|
|
:param retry_count: retry count
|
|
|
- :return: [0]-tokens used and [1]-answer
|
|
|
+ :return: a dict with keys "content" (answer text) and "completion_tokens"; on success it also includes "total_tokens"
|
|
|
'''
|
|
|
try:
|
|
|
response = openai.ChatCompletion.create(
|
|
|
@@ -68,8 +68,9 @@ class ChatGPTBot(Bot):
|
|
|
)
|
|
|
# res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
|
|
logger.info(response.choices[0]['message']['content'])
|
|
|
-
|
|
|
- return response["usage"]["prompt_tokens"],response.choices[0]['message']['content']
|
|
|
+ return {"total_tokens": response["usage"]["total_tokens"],
|
|
|
+ "completion_tokens": response["usage"]["completion_tokens"],
|
|
|
+ "content": response.choices[0]['message']['content']}
|
|
|
except openai.error.RateLimitError as e:
|
|
|
# rate limit exception
|
|
|
logger.warn(e)
|
|
|
@@ -78,21 +79,21 @@ class ChatGPTBot(Bot):
|
|
|
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
|
|
return self.reply_text(query, user_id, retry_count+1)
|
|
|
else:
|
|
|
- return 0,"提问太快啦,请休息一下再问我吧"
|
|
|
+ return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
|
|
except openai.error.APIConnectionError as e:
|
|
|
# api connection exception
|
|
|
logger.warn(e)
|
|
|
logger.warn("[OPEN_AI] APIConnection failed")
|
|
|
- return 0,"我连接不到你的网络"
|
|
|
+ return {"completion_tokens": 0, "content":"我连接不到你的网络"}
|
|
|
except openai.error.Timeout as e:
|
|
|
logger.warn(e)
|
|
|
logger.warn("[OPEN_AI] Timeout")
|
|
|
- return 0,"我没有收到你的消息"
|
|
|
+ return {"completion_tokens": 0, "content":"我没有收到你的消息"}
|
|
|
except Exception as e:
|
|
|
# unknown exception
|
|
|
logger.exception(e)
|
|
|
Session.clear_session(user_id)
|
|
|
- return 0,"请再问我一次吧"
|
|
|
+ return {"completion_tokens": 0, "content": "请再问我一次吧"}
|
|
|
|
|
|
def create_img(self, query, retry_count=0):
|
|
|
try:
|
|
|
@@ -143,7 +144,7 @@ class Session(object):
|
|
|
return session
|
|
|
|
|
|
@staticmethod
|
|
|
- def save_session(query, answer, user_id):
|
|
|
+ def save_session(answer, user_id, total_tokens):
|
|
|
max_tokens = conf().get("conversation_max_tokens")
|
|
|
if not max_tokens:
|
|
|
# default 3000
|
|
|
@@ -153,22 +154,23 @@ class Session(object):
|
|
|
session = user_session.get(user_id)
|
|
|
if session:
|
|
|
# append conversation
|
|
|
- gpt_item = {'role': 'assistant', 'content': answer[1]}
|
|
|
+ gpt_item = {'role': 'assistant', 'content': answer}
|
|
|
session.append(gpt_item)
|
|
|
|
|
|
# discard exceed limit conversation
|
|
|
- used_tokens=int(answer[0])
|
|
|
- # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
|
|
|
+ Session.discard_exceed_conversation(session, max_tokens, total_tokens)
|
|
|
|
|
|
- while used_tokens > max_tokens:
|
|
|
+ @staticmethod
|
|
|
+ def discard_exceed_conversation(session, max_tokens, total_tokens):
|
|
|
+ dec_tokens=int(total_tokens)
|
|
|
+ # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
|
|
|
+ while dec_tokens > max_tokens:
|
|
|
# pop first conversation
|
|
|
if len(session) > 0:
|
|
|
session.pop(0)
|
|
|
else:
|
|
|
break
|
|
|
-
|
|
|
- used_tokens=used_tokens-max_tokens
|
|
|
-
|
|
|
+ dec_tokens=dec_tokens-max_tokens
|
|
|
|
|
|
@staticmethod
|
|
|
def clear_session(user_id):
|