Kaynağa Gözat

Merge branch 'master' of github.com:zwssunny/chatgpt-on-wechat into zwssunny-master

lanvent, 3 years ago
parent commit: 80bf6a0c7a

+ 3 - 3
app.py

@@ -1,6 +1,6 @@
 # encoding:utf-8
 
-import config
+from config import conf, load_config
 from channel import channel_factory
 from common.log import logger
 
@@ -9,10 +9,10 @@ from plugins import *
 def run():
     try:
         # load config
-        config.load_config()
+        load_config()
 
         # create channel
-        channel_name='wx'
+        channel_name=conf().get('channel_type', 'wx')
         channel = channel_factory.create_channel(channel_name)
         if channel_name=='wx':
             PluginManager().load_plugins()

+ 3 - 3
bot/bot_factory.py

@@ -6,9 +6,9 @@ from common import const
 
 def create_bot(bot_type):
     """
-    create a channel instance
-    :param channel_type: channel type code
-    :return: channel instance
+    create a bot_type instance
+    :param bot_type: bot type code
+    :return: bot instance
     """
     if bot_type == const.BAIDU:
         # Baidu Unit对话接口

+ 85 - 48
channel/wechat/wechat_channel.py

@@ -5,6 +5,9 @@ wechat channel
 """
 
 import os
+import requests
+import io
+import time
 from lib import itchat
 import json
 from lib.itchat.content import *
@@ -17,17 +20,18 @@ from common.tmp_dir import TmpDir
 from config import conf
 from common.time_check import time_checker
 from plugins import *
-import requests
-import io
-import time
+from voice.audio_convert import mp3_to_wav
 
 
 thread_pool = ThreadPoolExecutor(max_workers=8)
+
+
 def thread_pool_callback(worker):
     worker_exception = worker.exception()
     if worker_exception:
         logger.exception("Worker return exception: {}".format(worker_exception))
 
+
 @itchat.msg_register(TEXT)
 def handler_single_msg(msg):
     WechatChannel().handle_text(msg)
@@ -48,6 +52,8 @@ def handler_group_voice(msg):
     WechatChannel().handle_group_voice(msg)
     return None
 
+
+
 class WechatChannel(Channel):
     def __init__(self):
         self.userName = None
@@ -55,14 +61,15 @@ class WechatChannel(Channel):
 
     def startup(self):
 
-        itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
+        itchat.instance.receivingRetryCount = 600  # 修改断线超时时间
         # login by scan QRCode
         hotReload = conf().get('hot_reload', False)
         try:
             itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
         except Exception as e:
             if hotReload:
-                logger.error("Hot reload failed, try to login without hot reload")
+                logger.error(
+                    "Hot reload failed, try to login without hot reload")
                 itchat.logout()
                 os.remove("itchat.pkl")
                 itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
@@ -105,7 +112,8 @@ class WechatChannel(Channel):
 
     @time_checker
     def handle_text(self, msg):
-        logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False))
+        logger.debug("[WX]receive text msg: " +
+                     json.dumps(msg, ensure_ascii=False))
         content = msg['Text']
         from_user_id = msg['FromUserName']
         to_user_id = msg['ToUserName']              # 接收人id
@@ -119,7 +127,7 @@ class WechatChannel(Channel):
                 other_user_id = from_user_id
         create_time = msg['CreateTime']             # 消息时间
         match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
-        if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60:    #跳过1分钟前的历史消息
+        if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60:  # 跳过1分钟前的历史消息
             logger.debug("[WX]history message skipped")
             return
         if "」\n- - - - - - - - - - - - - - -" in content:
@@ -130,9 +138,11 @@ class WechatChannel(Channel):
         elif match_prefix is None:
             return
         context = Context()
-        context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
+        context.kwargs = {'isgroup': False, 'msg': msg,
+                          'receiver': other_user_id, 'session_id': other_user_id}
 
-        img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
+        img_match_prefix = check_prefix(
+            content, conf().get('image_create_prefix'))
         if img_match_prefix:
             content = content.replace(img_match_prefix, '', 1).strip()
             context.type = ContextType.IMAGE_CREATE
@@ -140,15 +150,17 @@ class WechatChannel(Channel):
             context.type = ContextType.TEXT
 
         context.content = content
-        thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
+        thread_pool.submit(self.handle, context).add_done_callback(
+            thread_pool_callback)
 
     @time_checker
     def handle_group(self, msg):
-        logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
+        logger.debug("[WX]receive group msg: " +
+                     json.dumps(msg, ensure_ascii=False))
         group_name = msg['User'].get('NickName', None)
         group_id = msg['User'].get('UserName', None)
         create_time = msg['CreateTime']             # 消息时间
-        if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60:    #跳过1分钟前的历史消息
+        if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60:  # 跳过1分钟前的历史消息
             logger.debug("[WX]history group message skipped")
             return
         if not group_name:
@@ -166,12 +178,14 @@ class WechatChannel(Channel):
             return ""
         config = conf()
         match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \
-                       or check_contain(origin_content, config.get('group_chat_keyword'))
+            or check_contain(origin_content, config.get('group_chat_keyword'))
         if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
             context = Context()
-            context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id}
-            
-            img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
+            context.kwargs = {'isgroup': True,
+                              'msg': msg, 'receiver': group_id}
+
+            img_match_prefix = check_prefix(
+                content, conf().get('image_create_prefix'))
             if img_match_prefix:
                 content = content.replace(img_match_prefix, '', 1).strip()
                 context.type = ContextType.IMAGE_CREATE
@@ -187,7 +201,8 @@ class WechatChannel(Channel):
             else:
                 context['session_id'] = msg['ActualUserName']
 
-            thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
+            thread_pool.submit(self.handle, context).add_done_callback(
+                thread_pool_callback)
 
     def handle_group_voice(self, msg):
         if conf().get('group_speech_recognition', False) != True:
@@ -217,7 +232,7 @@ class WechatChannel(Channel):
             thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
 
     # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
-    def send(self, reply : Reply, receiver):
+    def send(self, reply: Reply, receiver):
         if reply.type == ReplyType.TEXT:
             itchat.send(reply.content, toUserName=receiver)
             logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
@@ -226,8 +241,9 @@ class WechatChannel(Channel):
             logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
         elif reply.type == ReplyType.VOICE:
             itchat.send_file(reply.content, toUserName=receiver)
-            logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
-        elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+            logger.info('[WX] sendFile={}, receiver={}'.format(
+                reply.content, receiver))
+        elif reply.type == ReplyType.IMAGE_URL:  # 从网络下载图片
             img_url = reply.content
             pic_res = requests.get(img_url, stream=True)
             image_storage = io.BytesIO()
@@ -235,8 +251,9 @@ class WechatChannel(Channel):
                 image_storage.write(block)
             image_storage.seek(0)
             itchat.send_image(image_storage, toUserName=receiver)
-            logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
-        elif reply.type == ReplyType.IMAGE: # 从文件读取图片
+            logger.info('[WX] sendImage url={}, receiver={}'.format(
+                img_url, receiver))
+        elif reply.type == ReplyType.IMAGE:  # 从文件读取图片
             image_storage = reply.content
             image_storage.seek(0)
             itchat.send_image(image_storage, toUserName=receiver)
@@ -250,9 +267,10 @@ class WechatChannel(Channel):
         reply = Reply()
 
         logger.debug('[WX] ready to handle context: {}'.format(context))
-        
+
         # reply的构建步骤
-        e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
+        e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
+            'channel': self, 'context': context, 'reply': reply}))
         reply = e_context['reply']
         if not e_context.is_pass():
             logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
@@ -260,22 +278,35 @@ class WechatChannel(Channel):
                 reply = super().build_reply_content(context.content, context)
             elif context.type == ContextType.VOICE: # 语音消息
                 msg = context['msg']
-                file_name = TmpDir().path() + context.content
-                msg.download(file_name)
-                reply = super().build_voice_to_text(file_name)
-                if reply.type == ReplyType.TEXT:
-                    content = reply.content # 语音转文字后,将文字内容作为新的context
-                    # 如果是群消息,判断是否触发关键字
-                    if context['isgroup']:
-                        match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
-                        match_contain = check_contain(content, conf().get('group_chat_keyword'))
-                        logger.debug('[WX] group chat prefix match: {}'.format(match_prefix))
-                        if match_prefix is None and match_contain is None:
-                            return
+                mp3_path = TmpDir().path() + context.content
+                msg.download(mp3_path)
+                # mp3转wav
+                wav_path = os.path.splitext(mp3_path)[0] + '.wav'
+                mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path)
+                # 语音识别
+                reply = super().build_voice_to_text(wav_path)
+                # 删除临时文件
+                os.remove(wav_path)
+                os.remove(mp3_path)
+                if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
+                    content = reply.content  # 语音转文字后,将文字内容作为新的context
+                    context.type = ContextType.TEXT
+                    if (context["isgroup"] == True):
+                        # 校验关键字
+                        match_prefix = check_prefix(content, conf().get('group_chat_prefix')) \
+                            or check_contain(content, conf().get('group_chat_keyword'))
+                        # Wechaty判断is_at为True,返回的内容是过滤掉@之后的内容;而is_at为False,则会返回完整的内容
+                        if match_prefix is not None:
+                            # 故判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
+                            prefixes = conf().get('group_chat_prefix')
+                            for prefix in prefixes:
+                                if content.startswith(prefix):
+                                    content = content.replace(prefix, '', 1).strip()
+                                    break
                         else:
-                            if match_prefix:
-                                content = content.replace(match_prefix, '', 1).strip()
-                        
+                            logger.info("[WX]receive voice check prefix: " + 'False')
+                            return
+                       
                     img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
                     if img_match_prefix:
                         content = content.replace(img_match_prefix, '', 1).strip()
@@ -292,16 +323,19 @@ class WechatChannel(Channel):
                 return
 
         logger.debug('[WX] ready to decorate reply: {}'.format(reply))
-        
+
         # reply的包装步骤
         if reply and reply.type:
-            e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
-            reply=e_context['reply']
+            e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
+                'channel': self, 'context': context, 'reply': reply}))
+            reply = e_context['reply']
             if not e_context.is_pass() and reply and reply.type:
                 if reply.type == ReplyType.TEXT:
                     reply_text = reply.content
                     if context['isgroup']:
-                        reply_text = '@' +  context['msg']['ActualNickName'] + ' ' + reply_text.strip()
+                        reply_text = '@' + \
+                            context['msg']['ActualNickName'] + \
+                            ' ' + reply_text.strip()
                         reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
                     else:
                         reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
@@ -311,15 +345,18 @@ class WechatChannel(Channel):
                 elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
                     pass
                 else:
-                    logger.error('[WX] unknown reply type: {}'.format(reply.type))
+                    logger.error(
+                        '[WX] unknown reply type: {}'.format(reply.type))
                     return
 
-        # reply的发送步骤   
+        # reply的发送步骤
         if reply and reply.type:
-            e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
-            reply=e_context['reply']
+            e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
+                'channel': self, 'context': context, 'reply': reply}))
+            reply = e_context['reply']
             if not e_context.is_pass() and reply and reply.type:
-                logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
+                logger.debug('[WX] ready to send reply: {} to {}'.format(
+                    reply, context['receiver']))
                 self.send(reply, context['receiver'])
 
 def check_prefix(content, prefix_list):

+ 17 - 66
channel/wechat/wechaty_channel.py

@@ -4,25 +4,19 @@
 wechaty channel
 Python Wechaty - https://github.com/wechaty/python-wechaty
 """
-import io
 import os
-import json
 import time
 import asyncio
-import requests
-import pysilk
-import wave
-from pydub import AudioSegment
 from typing import Optional, Union
 from bridge.context import Context, ContextType
 from wechaty_puppet import MessageType, FileBox, ScanStatus  # type: ignore
 from wechaty import Wechaty, Contact
-from wechaty.user import Message, Room, MiniProgram, UrlLink
+from wechaty.user import Message, MiniProgram, UrlLink
 from channel.channel import Channel
 from common.log import logger
 from common.tmp_dir import TmpDir
 from config import conf
-
+from voice.audio_convert import sil_to_wav, mp3_to_sil
 
 class WechatyChannel(Channel):
 
@@ -50,8 +44,9 @@ class WechatyChannel(Channel):
 
     async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None,
                       data: Optional[str] = None):
-        contact = self.Contact.load(self.contact_id)
-        logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code))
+        pass
+        # contact = self.Contact.load(self.contact_id)
+        # logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code))
         # print(f'user <{contact}> scan status: {status.name} , 'f'qr_code: {qr_code}')
 
     async def on_message(self, msg: Message):
@@ -67,7 +62,7 @@ class WechatyChannel(Channel):
         content = msg.text()
         mention_content = await msg.mention_text()  # 返回过滤掉@name后的消息
         match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
-        conversation: Union[Room, Contact] = from_contact if room is None else room
+        # conversation: Union[Room, Contact] = from_contact if room is None else room
 
         if room is None and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
             if not msg.is_self() and match_prefix is not None:
@@ -102,21 +97,8 @@ class WechatyChannel(Channel):
                 await voice_file.to_file(silk_file)
                 logger.info("[WX]receive voice file: " + silk_file)
                 # 将文件转成wav格式音频
-                wav_file = silk_file.replace(".slk", ".wav")
-                with open(silk_file, 'rb') as f:
-                    silk_data = f.read()
-                pcm_data = pysilk.decode(silk_data)
-
-                with wave.open(wav_file, 'wb') as wav_data:
-                    wav_data.setnchannels(1)
-                    wav_data.setsampwidth(2)
-                    wav_data.setframerate(24000)
-                    wav_data.writeframes(pcm_data)
-                if os.path.exists(wav_file): 
-                    converter_state = "true" # 转换wav成功
-                else:
-                    converter_state = "false" # 转换wav失败
-                logger.info("[WX]receive voice converter: " + converter_state)
+                wav_file = os.path.splitext(silk_file)[0] + '.wav'
+                sil_to_wav(silk_file, wav_file)
                 # 语音识别为文本
                 query = super().build_voice_to_text(wav_file).content
                 # 交验关键字
@@ -183,21 +165,8 @@ class WechatyChannel(Channel):
                 await voice_file.to_file(silk_file)
                 logger.info("[WX]receive voice file: " + silk_file)
                 # 将文件转成wav格式音频
-                wav_file = silk_file.replace(".slk", ".wav")
-                with open(silk_file, 'rb') as f:
-                    silk_data = f.read()
-                pcm_data = pysilk.decode(silk_data)
-
-                with wave.open(wav_file, 'wb') as wav_data:
-                    wav_data.setnchannels(1)
-                    wav_data.setsampwidth(2)
-                    wav_data.setframerate(24000)
-                    wav_data.writeframes(pcm_data)
-                if os.path.exists(wav_file): 
-                    converter_state = "true" # 转换wav成功
-                else:
-                    converter_state = "false" # 转换wav失败
-                logger.info("[WX]receive voice converter: " + converter_state)
+                wav_file = os.path.splitext(silk_file)[0] + '.wav'
+                sil_to_wav(silk_file, wav_file)
                 # 语音识别为文本
                 query = super().build_voice_to_text(wav_file).content
                 # 校验关键字
@@ -260,21 +229,12 @@ class WechatyChannel(Channel):
             if reply_text:
                 # 转换 mp3 文件为 silk 格式
                 mp3_file = super().build_text_to_voice(reply_text).content
-                silk_file = mp3_file.replace(".mp3", ".silk")
-                # Load the MP3 file
-                audio = AudioSegment.from_file(mp3_file, format="mp3")
-                # Convert to WAV format
-                audio = audio.set_frame_rate(24000).set_channels(1)
-                wav_data = audio.raw_data
-                sample_width = audio.sample_width
-                # Encode to SILK format
-                silk_data = pysilk.encode(wav_data, 24000)
-                # Save the silk file
-                with open(silk_file, "wb") as f:
-                    f.write(silk_data)
+                silk_file = os.path.splitext(mp3_file)[0] + '.sil'
+                voiceLength = mp3_to_sil(mp3_file, silk_file)
                 # 发送语音
                 t = int(time.time())
-                file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
+                file_box = FileBox.from_file(silk_file, name=str(t) + '.sil')
+                file_box.metadata = {'voiceLength': voiceLength}                
                 await self.send(file_box, reply_user_id)
                 # 清除缓存文件
                 os.remove(mp3_file)
@@ -337,21 +297,12 @@ class WechatyChannel(Channel):
             reply_text = '@' + group_user_name + ' ' + reply_text.strip()
             # 转换 mp3 文件为 silk 格式
             mp3_file = super().build_text_to_voice(reply_text).content
-            silk_file = mp3_file.replace(".mp3", ".silk")
-            # Load the MP3 file
-            audio = AudioSegment.from_file(mp3_file, format="mp3")
-            # Convert to WAV format
-            audio = audio.set_frame_rate(24000).set_channels(1)
-            wav_data = audio.raw_data
-            sample_width = audio.sample_width
-            # Encode to SILK format
-            silk_data = pysilk.encode(wav_data, 24000)
-            # Save the silk file
-            with open(silk_file, "wb") as f:
-                f.write(silk_data)
+            silk_file = os.path.splitext(mp3_file)[0] + '.sil'
+            voiceLength = mp3_to_sil(mp3_file, silk_file)
             # 发送语音
             t = int(time.time())
             file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
+            file_box.metadata = {'voiceLength': voiceLength}            
             await self.send_group(file_box, group_id)
             # 清除缓存文件
             os.remove(mp3_file)

+ 60 - 52
config.py

@@ -5,71 +5,77 @@ import os
 from common.log import logger
 
 # 将所有可用的配置项写在字典里, 请使用小写字母
-available_setting ={
-    #openai api配置
-    "open_ai_api_key": "", # openai api key
-    "open_ai_api_base": "https://api.openai.com/v1", # openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base
-    "proxy": "", # openai使用的代理
-    "model": "gpt-3.5-turbo", # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
-    "use_azure_chatgpt": False, # 是否使用azure的chatgpt
-
-    #Bot触发配置
-    "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
-    "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
-    "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
-    "group_chat_reply_prefix": "", # 群聊时自动回复的前缀
-    "group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
-    "group_at_off": False, # 是否关闭群聊时@bot的触发
-    "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
-    "group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
-    "group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
-    "image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
-    
-    #chatgpt会话参数
-    "expires_in_seconds": 3600, # 无操作会话的过期时间
-    "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
-    "conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
-            
-    #chatgpt限流配置
-    "rate_limit_chatgpt": 20, # chatgpt的调用频率限制
-    "rate_limit_dalle": 50, # openai dalle的调用频率限制
-
-
-    #chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
+available_setting = {
+    # openai api配置
+    "open_ai_api_key": "",  # openai api key
+    # openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base
+    "open_ai_api_base": "https://api.openai.com/v1",
+    "proxy": "",  # openai使用的代理
+    # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
+    "model": "gpt-3.5-turbo",
+    "use_azure_chatgpt": False,  # 是否使用azure的chatgpt
+
+    # Bot触发配置
+    "single_chat_prefix": ["bot", "@bot"],  # 私聊时文本需要包含该前缀才能触发机器人回复
+    "single_chat_reply_prefix": "[bot] ",  # 私聊时自动回复的前缀,用于区分真人
+    "group_chat_prefix": ["@bot"],  # 群聊时包含该前缀则会触发机器人回复
+    "group_chat_reply_prefix": "",  # 群聊时自动回复的前缀
+    "group_chat_keyword": [],  # 群聊时包含该关键词则会触发机器人回复
+    "group_at_off": False,  # 是否关闭群聊时@bot的触发
+    "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],  # 开启自动回复的群名称列表
+    "group_name_keyword_white_list": [],  # 开启自动回复的群名称关键词列表
+    "group_chat_in_one_session": ["ChatGPT测试群"],  # 支持会话上下文共享的群名称
+    "image_create_prefix": ["画", "看", "找"],  # 开启图片回复的前缀
+
+    # chatgpt会话参数
+    "expires_in_seconds": 3600,  # 无操作会话的过期时间
+    "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",  # 人格描述
+    "conversation_max_tokens": 1000,  # 支持上下文记忆的最多字符数
+
+    # chatgpt限流配置
+    "rate_limit_chatgpt": 20,  # chatgpt的调用频率限制
+    "rate_limit_dalle": 50,  # openai dalle的调用频率限制
+
+
+    # chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
     "temperature": 0.9,
     "top_p": 1,
     "frequency_penalty": 0,
     "presence_penalty": 0,
 
-    #语音设置
-    "speech_recognition": False, # 是否开启语音识别
-    "group_speech_recognition": False, # 是否开启群组语音识别
-    "voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
-    "voice_to_text": "openai", # 语音识别引擎,支持openai和google
-    "text_to_voice": "baidu", # 语音合成引擎,支持baidu和google
+    # 语音设置
+    "speech_recognition": False,  # 是否开启语音识别
+    "group_speech_recognition": False,  # 是否开启群组语音识别
+    "voice_reply_voice": False,  # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
+    "voice_to_text": "openai",  # 语音识别引擎,支持openai和google
+    "text_to_voice": "baidu",  # 语音合成引擎,支持baidu和google
 
     # baidu api的配置, 使用百度语音识别和语音合成时需要
-    'baidu_app_id': "",
-    'baidu_api_key': "",
-    'baidu_secret_key': "",
+    "baidu_app_id": "",
+    "baidu_api_key": "",
+    "baidu_secret_key": "",
+    # 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
+    "baidu_dev_pid": "1536",
 
-    #服务时间限制,目前支持itchat
-    "chat_time_module": False, # 是否开启服务时间限制
-    "chat_start_time": "00:00", # 服务开始时间
-    "chat_stop_time": "24:00", # 服务结束时间
+    # 服务时间限制,目前支持itchat
+    "chat_time_module": False,  # 是否开启服务时间限制
+    "chat_start_time": "00:00",  # 服务开始时间
+    "chat_stop_time": "24:00",  # 服务结束时间
 
     # itchat的配置
-    "hot_reload": False, # 是否开启热重载
+    "hot_reload": False,  # 是否开启热重载
 
     # wechaty的配置
-    "wechaty_puppet_service_token": "", # wechaty的token
+    "wechaty_puppet_service_token": "",  # wechaty的token
 
     # chatgpt指令自定义触发词
-    "clear_memory_commands": ['#清除记忆'], # 重置会话指令
+    "clear_memory_commands": ['#清除记忆'],  # 重置会话指令
+    "channel_type": "wx", # 通道类型,支持wx,wxy和terminal
 
 
 }
 
+
 class Config(dict):
     def __getitem__(self, key):
         if key not in available_setting:
@@ -82,15 +88,17 @@ class Config(dict):
         return super().__setitem__(key, value)
 
     def get(self, key, default=None):
-        try :
+        try:
             return self[key]
         except KeyError as e:
             return default
         except Exception as e:
             raise e
-    
+
+
 config = Config()
 
+
 def load_config():
     global config
     config_path = "./config.json"
@@ -109,7 +117,8 @@ def load_config():
     for name, value in os.environ.items():
         name = name.lower()
         if name in available_setting:
-            logger.info("[INIT] override config by environ args: {}={}".format(name, value))
+            logger.info(
+                "[INIT] override config by environ args: {}={}".format(name, value))
             try:
                 config[name] = eval(value)
             except:
@@ -118,9 +127,8 @@ def load_config():
     logger.info("[INIT] load config: {}".format(config))
 
 
-
 def get_root():
-    return os.path.dirname(os.path.abspath( __file__ ))
+    return os.path.dirname(os.path.abspath(__file__))
 
 
 def read_file(path):

+ 13 - 0
requirements.txt

@@ -1,3 +1,16 @@
 itchat-uos==1.5.0.dev0
 openai
 wechaty
+tiktoken
+whisper
+baidu-aip
+chardet
+ffmpy
+pydub
+pilk
+pysilk
+pysilk-mod
+wave
+SpeechRecognition
+pyttsx3
+gTTS

+ 60 - 0
voice/audio_convert.py

@@ -0,0 +1,60 @@
+import wave
+import pysilk
+from pydub import AudioSegment
+
+
+def get_pcm_from_wav(wav_path):
+    """
+    从 wav 文件中读取 pcm
+
+    :param wav_path: wav 文件路径
+    :returns: pcm 数据
+    """
+    wav = wave.open(wav_path, "rb")
+    return wav.readframes(wav.getnframes())
+
+
+def mp3_to_wav(mp3_path, wav_path):
+    """
+    把mp3格式转成pcm文件
+    """
+    audio = AudioSegment.from_mp3(mp3_path)
+    audio.export(wav_path, format="wav")
+
+
+def pcm_to_silk(pcm_path, silk_path):
+    """
+    wav 文件转成 silk
+    return 声音长度,毫秒
+    """
+    audio = AudioSegment.from_wav(pcm_path)
+    wav_data = audio.raw_data
+    silk_data = pysilk.encode(
+        wav_data, data_rate=audio.frame_rate, sample_rate=audio.frame_rate)
+    with open(silk_path, "wb") as f:
+        f.write(silk_data)
+    return audio.duration_seconds * 1000
+
+
+def mp3_to_sil(mp3_path, silk_path):
+    """
+    mp3 文件转成 silk
+    return 声音长度,毫秒
+    """
+    audio = AudioSegment.from_mp3(mp3_path)
+    wav_data = audio.raw_data
+    silk_data = pysilk.encode(
+        wav_data, data_rate=audio.frame_rate, sample_rate=audio.frame_rate)
+    # Save the silk file
+    with open(silk_path, "wb") as f:
+        f.write(silk_data)
+    return audio.duration_seconds * 1000
+
+
+def sil_to_wav(silk_path, wav_path, rate: int = 24000):
+    """
+    silk 文件转 wav
+    """
+    wav_data = pysilk.decode_file(silk_path, to_wav=True, sample_rate=rate)
+    with open(wav_path, "wb") as f:
+        f.write(wav_data)

+ 38 - 3
voice/baidu/baidu_voice.py

@@ -8,19 +8,53 @@ from bridge.reply import Reply, ReplyType
 from common.log import logger
 from common.tmp_dir import TmpDir
 from voice.voice import Voice
+from voice.audio_convert import get_pcm_from_wav
 from config import conf
+"""
+    百度的语音识别API.
+    dev_pid:
+        - 1936: 普通话远场
+        - 1536:普通话(支持简单的英文识别)
+        - 1537:普通话(纯中文识别)
+        - 1737:英语
+        - 1637:粤语
+        - 1837:四川话
+    要使用本模块, 首先到 yuyin.baidu.com 注册一个开发者账号,
+    之后创建一个新应用, 然后在应用管理的"查看key"中获得 API Key 和 Secret Key
+    填入 config.json 中.
+        baidu_app_id: ''
+        baidu_api_key: ''
+        baidu_secret_key: ''
+        baidu_dev_pid: '1536'
+"""
+
 
 class BaiduVoice(Voice):
     APP_ID = conf().get('baidu_app_id')
     API_KEY = conf().get('baidu_api_key')
     SECRET_KEY = conf().get('baidu_secret_key')
+    DEV_ID = conf().get('baidu_dev_pid')
     client = AipSpeech(APP_ID, API_KEY, SECRET_KEY)
-    
+
     def __init__(self):
         pass
 
     def voiceToText(self, voice_file):
-        pass
+        # 识别本地文件
+        logger.debug('[Baidu] voice file name={}'.format(voice_file))
+        pcm = get_pcm_from_wav(voice_file)
+        res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.DEV_ID})
+        if res["err_no"] == 0:
+            logger.info("百度语音识别到了:{}".format(res["result"]))
+            text = "".join(res["result"])
+            reply = Reply(ReplyType.TEXT, text)
+        else:
+            logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
+            if res["err_msg"] == "request pv too much":
+                logger.info("  出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
+            reply = Reply(ReplyType.ERROR,
+                          "百度语音识别出错了;{0}".format(res["err_msg"]))
+        return reply
 
     def textToVoice(self, text):
         result = self.client.synthesis(text, 'zh', 1, {
@@ -30,7 +64,8 @@ class BaiduVoice(Voice):
             fileName = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
             with open(fileName, 'wb') as f:
                 f.write(result)
-            logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
+            logger.info(
+                '[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
             reply = Reply(ReplyType.VOICE, fileName)
         else:
             logger.error('[Baidu] textToVoice error={}'.format(result))

+ 13 - 12
voice/google/google_voice.py

@@ -3,12 +3,11 @@
 google voice service
 """
 
-import pathlib
-import subprocess
 import time
-from bridge.reply import Reply, ReplyType
 import speech_recognition
 import pyttsx3
+from gtts import gTTS
+from bridge.reply import Reply, ReplyType
 from common.log import logger
 from common.tmp_dir import TmpDir
 from voice.voice import Voice
@@ -28,10 +27,10 @@ class GoogleVoice(Voice):
         self.engine.setProperty('voice', voices[1].id)
 
     def voiceToText(self, voice_file):
-        new_file = voice_file.replace('.mp3', '.wav')
-        subprocess.call('ffmpeg -i ' + voice_file +
-                        ' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True)
-        with speech_recognition.AudioFile(new_file) as source:
+        # new_file = voice_file.replace('.mp3', '.wav')
+        # subprocess.call('ffmpeg -i ' + voice_file +
+        #                 ' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True)
+        with speech_recognition.AudioFile(voice_file) as source:
             audio = self.recognizer.record(source)
         try:
             text = self.recognizer.recognize_google(audio, language='zh-CN')
@@ -46,12 +45,14 @@ class GoogleVoice(Voice):
             return reply
     def textToVoice(self, text):
         try:
-            textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
-            self.engine.save_to_file(text, textFile)
-            self.engine.runAndWait()
+            mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
+            # self.engine.save_to_file(text, textFile)
+            # self.engine.runAndWait()
+            tts = gTTS(text=text, lang='zh')
+            tts.save(mp3File)            
             logger.info(
-                '[Google] textToVoice text={} voice file name={}'.format(text, textFile))
-            reply = Reply(ReplyType.VOICE, textFile)
+                '[Google] textToVoice text={} voice file name={}'.format(text, mp3File))
+            reply = Reply(ReplyType.VOICE, mp3File)
         except Exception as e:
             reply = Reply(ReplyType.ERROR, str(e))
         finally: