|
|
@@ -4,6 +4,7 @@ from bot.bot import Bot
|
|
|
from config import conf
|
|
|
from common.log import logger
|
|
|
import openai
|
|
|
+from datetime import date
|
|
|
|
|
|
user_session = dict()
|
|
|
|
|
|
@@ -25,14 +26,15 @@ class OpenAIBot(Bot):
|
|
|
new_query = Session.build_session_query(query, from_user_id)
|
|
|
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
|
|
|
|
|
- reply_content = self.reply_text(new_query, query)
|
|
|
+ reply_content = self.reply_text(new_query, from_user_id)
|
|
|
+ logger.debug("[OPEN_AI] new_query={}, user={}".format(new_query, from_user_id))
|
|
|
Session.save_session(query, reply_content, from_user_id)
|
|
|
return reply_content
|
|
|
|
|
|
elif context.get('type', None) == 'IMAGE_CREATE':
|
|
|
return self.create_img(query)
|
|
|
|
|
|
- def reply_text(self, query, origin_query):
|
|
|
+ def reply_text(self, query, user_id):
|
|
|
try:
|
|
|
response = openai.Completion.create(
|
|
|
model="text-davinci-003", # 对话模型的名称
|
|
|
@@ -47,6 +49,7 @@ class OpenAIBot(Bot):
|
|
|
res_content = response.choices[0]["text"].strip().rstrip("<|im_end|>")
|
|
|
except Exception as e:
|
|
|
logger.exception(e)
|
|
|
+ Session.clear_session(user_id)
|
|
|
return None
|
|
|
logger.info("[OPEN_AI] reply={}".format(res_content))
|
|
|
return res_content
|
|
|
@@ -124,7 +127,7 @@ class Session(object):
|
|
|
if session:
|
|
|
for conversation in session:
|
|
|
prompt += "Q: " + conversation["question"] + "\n\n\nA: " + conversation["answer"] + "<|im_end|>\n"
|
|
|
- prompt += "Q: " + query + "\nA: "
|
|
|
+ prompt += "Q: " + query + "\nA: "
|
|
|
return prompt
|
|
|
else:
|
|
|
return prompt + "Q: " + query + "\nA: "
|
|
|
@@ -139,6 +142,8 @@ class Session(object):
|
|
|
conversation["question"] = query
|
|
|
conversation["answer"] = answer
|
|
|
session = user_session.get(user_id)
|
|
|
+ logger.debug(conversation)
|
|
|
+ logger.debug(session)
|
|
|
if session:
|
|
|
# append conversation
|
|
|
session.append(conversation)
|