|
|
@@ -2,6 +2,8 @@
|
|
|
|
|
|
from bot.bot import Bot
|
|
|
from bot.openai.open_ai_image import OpenAIImage
|
|
|
+from bot.openai.open_ai_session import OpenAISession
|
|
|
+from bot.session_manager import SessionManager
|
|
|
from bridge.context import ContextType
|
|
|
from bridge.reply import Reply, ReplyType
|
|
|
from config import conf
|
|
|
@@ -22,29 +24,34 @@ class OpenAIBot(Bot, OpenAIImage):
|
|
|
if proxy:
|
|
|
openai.proxy = proxy
|
|
|
|
|
|
+ self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003")
|
|
|
|
|
|
def reply(self, query, context=None):
|
|
|
# acquire reply content
|
|
|
if context and context.type:
|
|
|
if context.type == ContextType.TEXT:
|
|
|
logger.info("[OPEN_AI] query={}".format(query))
|
|
|
- from_user_id = context['session_id']
|
|
|
+ session_id = context['session_id']
|
|
|
reply = None
|
|
|
if query == '#清除记忆':
|
|
|
- Session.clear_session(from_user_id)
|
|
|
+ self.sessions.clear_session(session_id)
|
|
|
reply = Reply(ReplyType.INFO, '记忆已清除')
|
|
|
elif query == '#清除所有':
|
|
|
- Session.clear_all_session()
|
|
|
+ self.sessions.clear_all_session()
|
|
|
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
|
|
else:
|
|
|
- new_query = Session.build_session_query(query, from_user_id)
|
|
|
+ session = self.sessions.session_query(query, session_id)
|
|
|
+ new_query = str(session)
|
|
|
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
|
|
|
|
|
- reply_content = self.reply_text(new_query, from_user_id, 0)
|
|
|
- logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
|
|
- if reply_content and query:
|
|
|
- Session.save_session(query, reply_content, from_user_id)
|
|
|
- reply = Reply(ReplyType.TEXT, reply_content)
|
|
|
+ total_tokens, completion_tokens, reply_content = self.reply_text(new_query, session_id, 0)
|
|
|
+ logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(new_query, session_id, reply_content, completion_tokens))
|
|
|
+
|
|
|
+ if total_tokens == 0 :
|
|
|
+ reply = Reply(ReplyType.ERROR, reply_content)
|
|
|
+ else:
|
|
|
+ self.sessions.session_reply(reply_content, session_id, total_tokens)
|
|
|
+ reply = Reply(ReplyType.TEXT, reply_content)
|
|
|
return reply
|
|
|
elif context.type == ContextType.IMAGE_CREATE:
|
|
|
ok, retstring = self.create_img(query, 0)
|
|
|
@@ -68,8 +75,10 @@ class OpenAIBot(Bot, OpenAIImage):
|
|
|
stop=["\n\n\n"]
|
|
|
)
|
|
|
res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
|
|
+ total_tokens = response["usage"]["total_tokens"]
|
|
|
+ completion_tokens = response["usage"]["completion_tokens"]
|
|
|
logger.info("[OPEN_AI] reply={}".format(res_content))
|
|
|
- return res_content
|
|
|
+ return total_tokens, completion_tokens, res_content
|
|
|
except openai.error.RateLimitError as e:
|
|
|
# rate limit exception
|
|
|
logger.warn(e)
|
|
|
@@ -78,81 +87,9 @@ class OpenAIBot(Bot, OpenAIImage):
|
|
|
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
|
|
return self.reply_text(query, user_id, retry_count+1)
|
|
|
else:
|
|
|
- return "提问太快啦,请休息一下再问我吧"
|
|
|
+ return 0,0, "提问太快啦,请休息一下再问我吧"
|
|
|
except Exception as e:
|
|
|
# unknown exception
|
|
|
logger.exception(e)
|
|
|
Session.clear_session(user_id)
|
|
|
- return "请再问我一次吧"
|
|
|
-
|
|
|
-class Session(object):
|
|
|
- @staticmethod
|
|
|
- def build_session_query(query, user_id):
|
|
|
- '''
|
|
|
- build query with conversation history
|
|
|
- e.g. Q: xxx
|
|
|
- A: xxx
|
|
|
- Q: xxx
|
|
|
- :param query: query content
|
|
|
- :param user_id: from user id
|
|
|
- :return: query content with conversaction
|
|
|
- '''
|
|
|
- prompt = conf().get("character_desc", "")
|
|
|
- if prompt:
|
|
|
- prompt += "<|endoftext|>\n\n\n"
|
|
|
- session = user_session.get(user_id, None)
|
|
|
- if session:
|
|
|
- for conversation in session:
|
|
|
- prompt += "Q: " + conversation["question"] + "\n\n\nA: " + conversation["answer"] + "<|endoftext|>\n"
|
|
|
- prompt += "Q: " + query + "\nA: "
|
|
|
- return prompt
|
|
|
- else:
|
|
|
- return prompt + "Q: " + query + "\nA: "
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def save_session(query, answer, user_id):
|
|
|
- max_tokens = conf().get("conversation_max_tokens")
|
|
|
- if not max_tokens:
|
|
|
- # default 3000
|
|
|
- max_tokens = 1000
|
|
|
- conversation = dict()
|
|
|
- conversation["question"] = query
|
|
|
- conversation["answer"] = answer
|
|
|
- session = user_session.get(user_id)
|
|
|
- logger.debug(conversation)
|
|
|
- logger.debug(session)
|
|
|
- if session:
|
|
|
- # append conversation
|
|
|
- session.append(conversation)
|
|
|
- else:
|
|
|
- # create session
|
|
|
- queue = list()
|
|
|
- queue.append(conversation)
|
|
|
- user_session[user_id] = queue
|
|
|
-
|
|
|
- # discard exceed limit conversation
|
|
|
- Session.discard_exceed_conversation(user_session[user_id], max_tokens)
|
|
|
-
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def discard_exceed_conversation(session, max_tokens):
|
|
|
- count = 0
|
|
|
- count_list = list()
|
|
|
- for i in range(len(session)-1, -1, -1):
|
|
|
- # count tokens of conversation list
|
|
|
- history_conv = session[i]
|
|
|
- count += len(history_conv["question"]) + len(history_conv["answer"])
|
|
|
- count_list.append(count)
|
|
|
-
|
|
|
- for c in count_list:
|
|
|
- if c > max_tokens:
|
|
|
- # pop first conversation
|
|
|
- session.pop(0)
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def clear_session(user_id):
|
|
|
- user_session[user_id] = []
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def clear_all_session():
|
|
|
- user_session.clear()
|
|
|
+ return 0,0, "请再问我一次吧"
|