Selaa lähdekoodia

compatible with openai bot

lanvent 3 vuotta sitten
vanhempi
säilyke
dce9c4dccb
3 muutettua tiedostoa jossa 30 lisäystä ja 25 poistoa
  1. 3 1
      bot/baidu/baidu_unit_bot.py
  2. 25 22
      bot/openai/open_ai_bot.py
  3. 2 2
      plugins/godcmd/godcmd.py

+ 3 - 1
bot/baidu/baidu_unit_bot.py

@@ -2,6 +2,7 @@
 
 import requests
 from bot.bot import Bot
+from bridge.reply import Reply, ReplyType
 
 
 # Baidu Unit对话接口 (可用, 但能力较弱)
@@ -14,7 +15,8 @@ class BaiduUnitBot(Bot):
         headers = {'content-type': 'application/x-www-form-urlencoded'}
         response = requests.post(url, data=post_data.encode(), headers=headers)
         if response:
-            return response.json()['result']['context']['SYS_PRESUMED_HIST'][1]
+            reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
+            return reply
 
     def get_token(self):
         access_key = 'YOUR_ACCESS_KEY'

+ 25 - 22
bot/openai/open_ai_bot.py

@@ -1,6 +1,8 @@
 # encoding:utf-8
 
 from bot.bot import Bot
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
 from config import conf
 from common.log import logger
 import openai
@@ -13,30 +15,31 @@ class OpenAIBot(Bot):
     def __init__(self):
         openai.api_key = conf().get('open_ai_api_key')
 
-
     def reply(self, query, context=None):
         # acquire reply content
-        if not context or not context.get('type') or context.get('type') == 'TEXT':
-            logger.info("[OPEN_AI] query={}".format(query))
-            from_user_id = context.get('from_user_id') or context.get('session_id')
-            if query == '#清除记忆':
-                Session.clear_session(from_user_id)
-                return '记忆已清除'
-            elif query == '#清除所有':
-                Session.clear_all_session()
-                return '所有人记忆已清除'
-
-            new_query = Session.build_session_query(query, from_user_id)
-            logger.debug("[OPEN_AI] session query={}".format(new_query))
-
-            reply_content = self.reply_text(new_query, from_user_id, 0)
-            logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
-            if reply_content and query:
-                Session.save_session(query, reply_content, from_user_id)
-            return reply_content
-
-        elif context.get('type', None) == 'IMAGE_CREATE':
-            return self.create_img(query, 0)
+        if context and context.type:
+            if context.type == ContextType.TEXT:
+                logger.info("[OPEN_AI] query={}".format(query))
+                from_user_id = context['session_id']
+                reply = None
+                if query == '#清除记忆':
+                    Session.clear_session(from_user_id)
+                    reply = Reply(ReplyType.INFO, '记忆已清除')
+                elif query == '#清除所有':
+                    Session.clear_all_session()
+                    reply = Reply(ReplyType.INFO, '所有人记忆已清除')
+                else:
+                    new_query = Session.build_session_query(query, from_user_id)
+                    logger.debug("[OPEN_AI] session query={}".format(new_query))
+
+                    reply_content = self.reply_text(new_query, from_user_id, 0)
+                    logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
+                    if reply_content and query:
+                        Session.save_session(query, reply_content, from_user_id)
+                    reply = Reply(ReplyType.TEXT, reply_content)
+                return reply
+            elif context.type == ContextType.IMAGE_CREATE:
+                return self.create_img(query, 0)
 
     def reply_text(self, query, user_id, retry_count=0):
         try:

+ 2 - 2
plugins/godcmd/godcmd.py

@@ -162,7 +162,7 @@ class Godcmd(Plugin):
                         bot.sessions.clear_session(session_id)
                         ok, result = True, "会话已重置"
                     else:
-                        ok, result = False, "当前机器人不支持重置会话"
+                        ok, result = False, "当前对话机器人不支持重置会话"
                 logger.debug("[Godcmd] command: %s by %s" % (cmd, user))
             elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()):
                 if isadmin:
@@ -184,7 +184,7 @@ class Godcmd(Plugin):
                                 bot.sessions.clear_all_session()
                                 ok, result = True, "重置所有会话成功"
                             else:
-                                ok, result = False, "当前机器人不支持重置会话"
+                                ok, result = False, "当前对话机器人不支持重置会话"
                         elif cmd == "debug":
                             logger.setLevel('DEBUG')
                             ok, result = True, "DEBUG模式已开启"