Browse Source

feat: add config for model selection #471

zhayujie 3 năm trước cách đây
mục cha
commit
3c04325aae

+ 2 - 6
README.md

@@ -96,12 +96,8 @@ pip3 install --upgrade openai
 # config.json文件内容示例
 { 
   "open_ai_api_key": "YOUR API KEY",                          # 填入上面创建的 OpenAI API KEY
-  "open_ai_api_base": "https://api.openai.com/v1",            # 自定义 OpenAI API 地址
+  "model": "gpt-3.5-turbo",                                   # 模型名称
   "proxy": "127.0.0.1:7890",                                  # 代理客户端的ip和端口
-  "baidu_app_id": "",                                         # 百度AI的App Id
-  "baidu_api_key": "",                                        # 百度AI的API KEY
-  "baidu_secret_key": "",                                     # 百度AI的Secret KEY
-  "wechaty_puppet_service_token":"",                          # wechaty服务token
   "single_chat_prefix": ["bot", "@bot"],                      # 私聊时文本需要包含该前缀才能触发机器人回复
   "single_chat_reply_prefix": "[bot] ",                       # 私聊时自动回复的前缀,用于区分真人
   "group_chat_prefix": ["@bot"],                              # 群聊时包含该前缀则会触发机器人回复
@@ -109,7 +105,6 @@ pip3 install --upgrade openai
   "image_create_prefix": ["画", "看", "找"],                   # 开启图片回复的前缀
   "conversation_max_tokens": 1000,                            # 支持上下文记忆的最多字符数
   "speech_recognition": false,                                # 是否开启语音识别
-  "voice_reply_voice": false,                                 # 是否开启语音回复
   "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",  # 人格描述
 }
 ```
@@ -133,6 +128,7 @@ pip3 install --upgrade openai
 
 **4.其他配置**
 
++ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k`  (其中gpt-4 api暂未开放)
 + `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考  [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
 + 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
 + 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions)  文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。

+ 4 - 3
bot/bot_factory.py

@@ -1,6 +1,7 @@
 """
 channel factory
 """
+from common import const
 
 
 def create_bot(bot_type):
@@ -9,17 +10,17 @@ def create_bot(bot_type):
     :param channel_type: channel type code
     :return: channel instance
     """
-    if bot_type == 'baidu':
+    if bot_type == const.BAIDU:
         # Baidu Unit对话接口
         from bot.baidu.baidu_unit_bot import BaiduUnitBot
         return BaiduUnitBot()
 
-    elif bot_type == 'chatGPT':
+    elif bot_type == const.CHATGPT:
         # ChatGPT 网页端web接口
         from bot.chatgpt.chat_gpt_bot import ChatGPTBot
         return ChatGPTBot()
 
-    elif bot_type == 'openAI':
+    elif bot_type == const.OPEN_AI:
         # OpenAI 官方对话模型API
         from bot.openai.open_ai_bot import OpenAIBot
         return OpenAIBot()

+ 1 - 1
bot/chatgpt/chat_gpt_bot.py

@@ -63,7 +63,7 @@ class ChatGPTBot(Bot):
         '''
         try:
             response = openai.ChatCompletion.create(
-                model="gpt-3.5-turbo",  # 对话模型的名称
+                model= conf().get("model") or "gpt-3.5-turbo",  # 对话模型的名称
                 messages=session,
                 temperature=0.9,  # 值在[0,1]之间,越大表示回复越具有不确定性
                 #max_tokens=4096,  # 回复最大的字符数

+ 1 - 1
bot/openai/open_ai_bot.py

@@ -45,7 +45,7 @@ class OpenAIBot(Bot):
     def reply_text(self, query, user_id, retry_count=0):
         try:
             response = openai.Completion.create(
-                model="text-davinci-003",  # 对话模型的名称
+                model= conf().get("model") or "text-davinci-003",  # 对话模型的名称
                 prompt=query,
                 temperature=0.9,  # 值在[0,1]之间,越大表示回复越具有不确定性
                 max_tokens=1200,  # 回复最大的字符数

+ 9 - 1
bridge/bridge.py

@@ -1,5 +1,7 @@
 from bot import bot_factory
 from voice import voice_factory
+from config import conf
+from common import const
 
 
 class Bridge(object):
@@ -7,7 +9,13 @@ class Bridge(object):
         pass
 
     def fetch_reply_content(self, query, context):
-        return bot_factory.create_bot("chatGPT").reply(query, context)
+        bot_type = const.CHATGPT
+        model_type = conf().get("model")
+        if model_type in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"]:
+            bot_type = const.CHATGPT
+        elif model_type in ["text-davinci-003"]:
+            bot_type = const.OPEN_AI
+        return bot_factory.create_bot(bot_type).reply(query, context)
 
     def fetch_voice_to_text(self, voiceFile):
         return voice_factory.create_voice("openai").voiceToText(voiceFile)

+ 3 - 0
channel/channel_factory.py

@@ -14,4 +14,7 @@ def create_channel(channel_type):
     elif channel_type == 'wxy':
         from channel.wechat.wechaty_channel import WechatyChannel
         return WechatyChannel()
+    elif channel_type == 'terminal':
+        from channel.terminal.terminal_channel import TerminalChannel
+        return TerminalChannel()
     raise RuntimeError

+ 29 - 0
channel/terminal/terminal_channel.py

@@ -0,0 +1,29 @@
+from channel.channel import Channel
+import sys
+
+class TerminalChannel(Channel):
+    def startup(self):
+        context = {"from_user_id": "User"}
+        print("\nPlease input your question")
+        while True:
+            try:
+                prompt = self.get_input("User:\n")
+            except KeyboardInterrupt:
+                print("\nExiting...")
+                sys.exit()
+
+            print("Bot:")
+            sys.stdout.flush()
+            for res in super().build_reply_content(prompt, context):
+                print(res, end="")
+                sys.stdout.flush()
+            print("\n")
+
+
+    def get_input(self, prompt):
+        """
+        Multi-line input function
+        """
+        print(prompt, end="")
+        line = input()
+        return line

+ 4 - 0
common/const.py

@@ -0,0 +1,4 @@
+# bot_type
+OPEN_AI = "openAI"
+CHATGPT = "chatGPT"
+BAIDU = "baidu"

+ 1 - 0
config-template.json

@@ -1,5 +1,6 @@
 {
   "open_ai_api_key": "YOUR API KEY",
+  "model": "gpt-3.5-turbo",
   "proxy": "",
   "single_chat_prefix": ["bot", "@bot"],
   "single_chat_reply_prefix": "[bot] ",