|
|
@@ -7,14 +7,15 @@ import requests
|
|
|
|
|
|
from bot.bot import Bot
|
|
|
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
|
|
+from bot.openai.open_ai_image import OpenAIImage
|
|
|
from bot.session_manager import SessionManager
|
|
|
-from bridge.context import Context
|
|
|
+from bridge.context import Context, ContextType
|
|
|
from bridge.reply import Reply, ReplyType
|
|
|
from common.log import logger
|
|
|
from config import conf
|
|
|
|
|
|
|
|
|
-class LinkAIBot(Bot):
|
|
|
+class LinkAIBot(Bot, OpenAIImage):
|
|
|
# authentication failed
|
|
|
AUTH_FAILED_CODE = 401
|
|
|
NO_QUOTA_CODE = 406
|
|
|
@@ -24,7 +25,19 @@ class LinkAIBot(Bot):
|
|
|
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
|
|
|
|
|
def reply(self, query, context: Context = None) -> Reply:
|
|
|
- return self._chat(query, context)
|
|
|
+ if context.type == ContextType.TEXT:
|
|
|
+ return self._chat(query, context)
|
|
|
+ elif context.type == ContextType.IMAGE_CREATE:
|
|
|
+ ok, retstring = self.create_img(query, 0)
|
|
|
+ reply = None
|
|
|
+ if ok:
|
|
|
+ reply = Reply(ReplyType.IMAGE_URL, retstring)
|
|
|
+ else:
|
|
|
+ reply = Reply(ReplyType.ERROR, retstring)
|
|
|
+ return reply
|
|
|
+ else:
|
|
|
+ reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
|
|
+ return reply
|
|
|
|
|
|
def _chat(self, query, context, retry_count=0):
|
|
|
if retry_count >= 2:
|