|
|
@@ -1,31 +1,78 @@
|
|
|
from bridge.context import *
|
|
|
-from channel.channel import Channel
|
|
|
+from bridge.reply import Reply, ReplyType
|
|
|
+from channel.chat_channel import ChatChannel, check_prefix
|
|
|
+from channel.chat_message import ChatMessage
|
|
|
import sys
|
|
|
|
|
|
-class TerminalChannel(Channel):
|
|
|
+from config import conf
|
|
|
+from common.log import logger
|
|
|
+
|
|
|
+class TerminalMessage(ChatMessage):
|
|
|
+ def __init__(self, msg_id, content, ctype = ContextType.TEXT, from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"):
|
|
|
+ self.msg_id = msg_id
|
|
|
+ self.ctype = ctype
|
|
|
+ self.content = content
|
|
|
+ self.from_user_id = from_user_id
|
|
|
+ self.to_user_id = to_user_id
|
|
|
+ self.other_user_id = other_user_id
|
|
|
+
|
|
|
+class TerminalChannel(ChatChannel):
|
|
|
+ NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
|
|
|
+
|
|
|
+ def send(self, reply: Reply, context: Context):
|
|
|
+ print("\nBot:")
|
|
|
+ if reply.type == ReplyType.IMAGE:
|
|
|
+ from PIL import Image
|
|
|
+ image_storage = reply.content
|
|
|
+ image_storage.seek(0)
|
|
|
+ img = Image.open(image_storage)
|
|
|
+ print("<IMAGE>")
|
|
|
+ img.show()
|
|
|
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
|
|
+ from PIL import Image
|
|
|
+ import requests,io
|
|
|
+ img_url = reply.content
|
|
|
+ pic_res = requests.get(img_url, stream=True)
|
|
|
+ image_storage = io.BytesIO()
|
|
|
+ for block in pic_res.iter_content(1024):
|
|
|
+ image_storage.write(block)
|
|
|
+ image_storage.seek(0)
|
|
|
+ img = Image.open(image_storage)
|
|
|
+ print(img_url)
|
|
|
+ img.show()
|
|
|
+ else:
|
|
|
+ print(reply.content)
|
|
|
+ print("\nUser:", end="")
|
|
|
+ sys.stdout.flush()
|
|
|
+ return
|
|
|
+
|
|
|
def startup(self):
|
|
|
context = Context()
|
|
|
- print("\nPlease input your question")
|
|
|
+ logger.setLevel("WARN")
|
|
|
+ print("\nPlease input your question:\nUser:", end="")
|
|
|
+ sys.stdout.flush()
|
|
|
+ msg_id = 0
|
|
|
while True:
|
|
|
try:
|
|
|
- prompt = self.get_input("User:\n")
|
|
|
+ prompt = self.get_input()
|
|
|
except KeyboardInterrupt:
|
|
|
print("\nExiting...")
|
|
|
sys.exit()
|
|
|
+ msg_id += 1
|
|
|
+ trigger_prefixs = conf().get("single_chat_prefix",[""])
|
|
|
+ if check_prefix(prompt, trigger_prefixs) is None:
|
|
|
+ prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
|
|
+
|
|
|
+ context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt))
|
|
|
+ if context:
|
|
|
+ self.produce(context)
|
|
|
+ else:
|
|
|
+ raise Exception("context is None")
|
|
|
|
|
|
- context.type = ContextType.TEXT
|
|
|
- context['session_id'] = "User"
|
|
|
- context.content = prompt
|
|
|
- print("Bot:")
|
|
|
- sys.stdout.flush()
|
|
|
- res = super().build_reply_content(prompt, context).content
|
|
|
- print(res)
|
|
|
-
|
|
|
-
|
|
|
- def get_input(self, prompt):
|
|
|
+ def get_input(self):
|
|
|
"""
|
|
|
Multi-line input function
|
|
|
"""
|
|
|
- print(prompt, end="")
|
|
|
+ sys.stdout.flush()
|
|
|
line = input()
|
|
|
return line
|