Jelajahi Sumber

feat: picture auto-generate

zhayujie 3 tahun lalu
induk
melakukan
bf50694658
9 mengubah file dengan 186 tambahan dan 64 penghapusan
  1. 2 1
      .gitignore
  2. 29 13
      README.md
  3. 12 6
      app.py
  4. 1 1
      bot/chatgpt/chat_gpt_bot.py
  5. 58 6
      bot/openai/open_ai_bot.py
  6. 72 24
      channel/wechat/wechat_channel.py
  7. 4 3
      config-template.json
  8. 8 10
      config.py
  9. TEMPAT SAMPAH
      docs/images/image-create-sample.jpg

+ 2 - 1
.gitignore

@@ -1,5 +1,6 @@
 .DS_Store
 .idea
 __pycache__/
-venv
+venv*
 *.pyc
+config.json

+ 29 - 13
README.md

@@ -5,10 +5,11 @@
  
 本项目是基于ChatGPT的微信聊天机器人,通过 [OpenAI](https://github.com/openai/openai-quickstart-python) 接口生成对话内容,使用 [itchat](https://github.com/littlecodersh/ItChat) 实现微信消息的接收和自动回复。已实现的特性如下:
 
-- [x] **基础功能:** 接收私聊及群组中的微信消息,使用ChatGPT生成回复内容,完成自动回复
+- [x] **w文本对话:** 接收私聊及群组中的微信消息,使用ChatGPT生成回复内容,完成自动回复
 - [x] **规则定制化:** 支持私聊中按指定规则触发自动回复,支持对群组设置自动回复白名单
 - [x] **多账号:** 支持多微信账号同时运行
-- [ ] **会话上下文:** 支持用户维度的上下文记忆
+-  [x]  **图片生成:** 支持根据描述生成图片,并自动发送至个人聊天或群聊
+
 
 # 更新
 > **2022.12.17:**  原来的方案是从 [ChatGPT页面](https://chat.openai.com/chat) 获取session_token,使用 [revChatGPT](https://github.com/acheong08/ChatGPT) 直接访问web接口,但随着ChatGPT接入Cloudflare人机验证,这一方案难以在服务器顺利运行。 所以目前使用的方案是调用 OpenAI 官方提供的 [API](https://beta.openai.com/docs/api-reference/introduction),回复质量上基本接近于ChatGPT的内容,劣势是暂不支持有上下文记忆的对话,优势是稳定性和响应速度较好。
@@ -23,6 +24,11 @@
 
 ![group-chat-sample.jpg](docs/images/group-chat-sample.jpg)
 
+### 图片生成
+
+![group-chat-sample.jpg](docs/images/image-create-sample.jpg)
+
+
 # 快速开始
 
 ## 准备
@@ -34,12 +40,12 @@
 
 前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,参考这篇 [博客](https://www.cnblogs.com/damugua/p/16969508.html) 可以通过虚拟手机号来接收验证码。创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。 
 
-> 项目中使用的对话模型是 davinci,计费方式是每1k字 (包含请求和回复) 消耗 $0.02,账号创建有免费的 $18 额度,使用完可以更换邮箱重新注册。
+> 项目中使用的对话模型是 davinci,计费方式是每1k字 (包含请求和回复) 消耗 $0.02,图片生成是每张消耗 $0.016,账号创建有免费的 $18 额度,使用完可以更换邮箱重新注册。
 
 
 ### 3.运行环境
 
-支持运行在 Linux、MacOS、Windows 系统上,需安装有 `Python`(版本在3.7.1 ~ 3.8.16 之间),推荐使用Linux服务器,可以托管在后台长期运行。
+支持运行在 Linux、MacOS、Windows 系统上,且安装有 `Python`(版本需在 3.7.1~3.8.10 之间),推荐使用Linux服务器,可托管于后台长期运行。
 
 克隆项目代码:
 
@@ -50,28 +56,38 @@ https://github.com/zhayujie/chatgpt-on-wechat
 安装所需核心依赖:
 
 ```bash
-pip3 install itchat
-pip3 install openai
+pip3 install itchat==1.3.10      	
+pip3 install openai==0.25.0
 ```
+> 注: 图片生成功能依赖openai 0.25.0版本,要求Python版本在3.7.1以上,而itchat目前只能运行于Python3.9以下版本,故需要Python版本在 3.7.1~3.9 之间,推荐使用 3.7.X 版本。
 
 ## 配置
 
-配置文件在根目录的 `config.json` 中,示例文件及各配置项含义如下:
+配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
+
+```bash
+cp config-template.json config.json
+```
+
+然后填入自定义配置,各配置项含义如下:
 
 ```bash
+# config.json文件内容示例
 { 
-  "open_ai_api_key": "${YOUR API KEY}$"                      # 上面创建的 OpenAI API KEY
-  "single_chat_prefix": ["bot", "@bot"],                     # 私聊时文本需要包含该前缀才能触发机器人回复
-  "single_chat_reply_prefix": "[bot] ",                      # 私聊时自动回复的前缀,用于区分真人
-  "group_chat_prefix": ["@bot"],                             # 群聊时包含该前缀则会触发机器人回复
-  "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"] # 开启自动回复的群名称列表
+  "open_ai_api_key": "${YOUR API KEY}$"                       # 必填,上面创建的 OpenAI API KEY
+  "single_chat_prefix": ["bot", "@bot"],                      # 私聊时文本需要包含该前缀才能触发机器人回复
+  "single_chat_reply_prefix": "[bot] ",                       # 私聊时自动回复的前缀,用于区分真人
+  "group_chat_prefix": ["@bot"],                              # 群聊时包含该前缀则会触发机器人回复
+  "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
+  "image_create_prefix": ["画", "看", "找"]                                                       # 开启图片生成的前缀
 }
 ```
 **配置说明:**
 
 + 个人聊天中,会以 "bot" 或 "@bot" 为开头的内容触发机器人,对应配置中的 `single_chat_prefix`;机器人回复的内容会以 "[bot]" 作为前缀, 以区分真人,对应的配置为 `single_chat_reply_prefix`
 + 群组聊天中,群名称需配置在 `group_name_white_list ` 中才能开启群聊自动回复,默认只要被@就会触发机器人自动回复,另外群聊天中只要检测到以 "@bot" 开头的内容,同样会自动回复,这对应配置 `group_chat_prefix`
-+ 关于OpenAI对话接口的参数配置,可以参考 [接口文档](https://beta.openai.com/docs/api-reference/completions) 直接在代码 `bot\openai\open_ai_bot.py` 中进行调整
++ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词,对应配置 `image_create_prefix `
++ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions)  文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot\openai\open_ai_bot.py` 中进行调整。
 
 
 ## 运行

+ 12 - 6
app.py

@@ -2,13 +2,19 @@
 
 import config
 from channel import channel_factory
+from common.log import logger
+
 
 if __name__ == '__main__':
-    # load config
-    config.load_config()
+    try:
+        # load config
+        config.load_config()
 
-    # create channel
-    channel = channel_factory.create_channel("wx")
+        # create channel
+        channel = channel_factory.create_channel("wx")
 
-    # startup channel
-    channel.startup()
+        # startup channel
+        channel.startup()
+    except Exception as e:
+        logger.error("App startup failed!")
+        logger.exception(e)

+ 1 - 1
bot/chatgpt/chat_gpt_bot.py

@@ -51,5 +51,5 @@ class ChatGPTBot(Bot):
             user_session[from_user_id] = user_cache
             return res['message']
         except Exception as e:
-            logger.error(e)
+            logger.exception(e)
             return None

+ 58 - 6
bot/openai/open_ai_bot.py

@@ -12,23 +12,75 @@ class OpenAIBot(Bot):
         openai.api_key = conf().get('open_ai_api_key')
 
     def reply(self, query, context=None):
+        if not context or not context.get('type') or context.get('type') == 'TEXT':
+            return self.reply_text(query)
+        elif context.get('type', None) == 'IMAGE_CREATE':
+            return self.create_img(query)
+
+    def reply_text(self, query):
         logger.info("[OPEN_AI] query={}".format(query))
         try:
             response = openai.Completion.create(
-                model="text-davinci-003",      #对话模型的名称
+                model="text-davinci-003",  # 对话模型的名称
                 prompt=query,
-                temperature=0.9,               #值在[0,1]之间,越大表示回复越具有不确定性
-                max_tokens=1200,               #回复最大的字符数
+                temperature=0.9,  # 值在[0,1]之间,越大表示回复越具有不确定性
+                max_tokens=1200,  # 回复最大的字符数
                 top_p=1,
-                frequency_penalty=0.0,         #[-2,2]之间,该值越大则更倾向于产生不同的内容
-                presence_penalty=0.6,          #[-2,2]之间,该值越大则更倾向于产生不同的内容
+                frequency_penalty=0.0,  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+                presence_penalty=0.6,  # [-2,2]之间,该值越大则更倾向于产生不同的内容
                 stop=["#"]
             )
             res_content = response.choices[0]["text"].strip()
         except Exception as e:
-            logger.error(e)
+            logger.exception(e)
             return None
         logger.info("[OPEN_AI] reply={}".format(res_content))
         return res_content
 
+    def create_img(self, query):
+        try:
+            logger.info("[OPEN_AI] image_query={}".format(query))
+            response = openai.Image.create(
+                prompt=query,    #图片描述
+                n=1,             #每次生成图片的数量
+                size="256x256"   #图片大小,可选有 256x256, 512x512, 1024x1024
+            )
+            image_url = response['data'][0]['url']
+            logger.info("[OPEN_AI] image_url={}".format(image_url))
+        except Exception as e:
+            logger.exception(e)
+            return None
+        return image_url
+
+    def edit_img(self, query, src_img):
+        openai.api_key = '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'
+        try:
+            response = openai.Image.create_edit(
+                image=open(src_img, 'rb'),
+                mask=open('cat-mask.png', 'rb'),
+                prompt=query,
+                n=1,
+                size='512x512'
+            )
+            image_url = response['data'][0]['url']
+            logger.info("[OPEN_AI] image_url={}".format(image_url))
+        except Exception as e:
+            logger.exception(e)
+            return None
+        return image_url
+
+    def migration_img(self, query, src_img):
+        openai.api_key = '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'
 
+        try:
+            response = openai.Image.create_variation(
+                image=open(src_img, 'rb'),
+                n=1,
+                size="512x512"
+            )
+            image_url = response['data'][0]['url']
+            logger.info("[OPEN_AI] image_url={}".format(image_url))
+        except Exception as e:
+            logger.exception(e)
+            return None
+        return image_url

+ 72 - 24
channel/wechat/wechat_channel.py

@@ -10,6 +10,8 @@ from channel.channel import Channel
 from concurrent.futures import ThreadPoolExecutor
 from common.log import logger
 from config import conf
+import requests
+import io
 
 thead_pool = ThreadPoolExecutor(max_workers=8)
 
@@ -17,11 +19,13 @@ thead_pool = ThreadPoolExecutor(max_workers=8)
 @itchat.msg_register(TEXT)
 def handler_single_msg(msg):
     WechatChannel().handle(msg)
+    return None
 
 
 @itchat.msg_register(TEXT, isGroupChat=True)
 def handler_group_msg(msg):
     WechatChannel().handle_group(msg)
+    return None
 
 
 class WechatChannel(Channel):
@@ -38,26 +42,40 @@ class WechatChannel(Channel):
     def handle(self, msg):
         logger.info("[WX]receive msg: " + json.dumps(msg, ensure_ascii=False))
         from_user_id = msg['FromUserName']
-        to_user_id = msg['ToUserName']
-        other_user_id = msg['User']['UserName']
+        to_user_id = msg['ToUserName']              # 接收人id
+        other_user_id = msg['User']['UserName']     # 对手方id
         content = msg['Text']
-        if from_user_id == other_user_id and \
-                self.check_prefix(content, conf().get('single_chat_prefix')):
-            str_list = content.split('bot', 1)
+        match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
+        if from_user_id == other_user_id and match_prefix:
+            # 好友向自己发送消息
+            str_list = content.split(match_prefix, 1)
             if len(str_list) == 2:
                 content = str_list[1].strip()
-            thead_pool.submit(self._do_send, content, from_user_id)
-        elif to_user_id == other_user_id and \
-                self.check_prefix(content, conf().get('single_chat_prefix')):
-            str_list = content.split('bot', 1)
+
+            img_match_prefix = self.check_prefix(content, ["画图"])
+            if img_match_prefix:
+                content = content.split(img_match_prefix, 1)[1].strip()
+                thead_pool.submit(self._do_send_img, content, from_user_id)
+            else:
+                thead_pool.submit(self._do_send, content, from_user_id)
+
+        elif to_user_id == other_user_id and match_prefix:
+            # 自己给好友发送消息
+            str_list = content.split(match_prefix, 1)
             if len(str_list) == 2:
                 content = str_list[1].strip()
-            thead_pool.submit(self._do_send, content, to_user_id)
+            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
+            if img_match_prefix:
+                content = content.split(img_match_prefix, 1)[1].strip()
+                thead_pool.submit(self._do_send_img, content, to_user_id)
+            else:
+                thead_pool.submit(self._do_send, content, to_user_id)
 
 
     def handle_group(self, msg):
         logger.info("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
         group_name = msg['User'].get('NickName', None)
+        group_id = msg['User'].get('UserName', None)
         if not group_name:
             return ""
         origin_content = msg['Content']
@@ -70,23 +88,53 @@ class WechatChannel(Channel):
             content = content_list[1]
 
         config = conf()
-        if group_name in config.get('group_name_white_list') \
-                and (msg['IsAt'] or self.check_prefix(origin_content, config.get('group_chat_prefix'))):
-            thead_pool.submit(self._do_send_group, content, msg)
+        match_prefix = msg['IsAt'] or self.check_prefix(origin_content, config.get('group_chat_prefix'))
+        if group_name in config.get('group_name_white_list') and match_prefix:
+            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
+            if img_match_prefix:
+                content = content.split(img_match_prefix, 1)[1].strip()
+                thead_pool.submit(self._do_send_img, content, group_id)
+            else:
+                thead_pool.submit(self._do_send_group, content, msg)
 
     def send(self, msg, receiver):
-        # time.sleep(random.randint(1, 3))
         logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver))
         itchat.send(msg, toUserName=receiver)
 
     def _do_send(self, query, reply_user_id):
-        if not query:
-            return
-        context = dict()
-        context['from_user_id'] = reply_user_id
-        reply_text = super().build_reply_content(query, context).strip()
-        if reply_text:
-            self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
+        try:
+            if not query:
+                return
+            context = dict()
+            context['from_user_id'] = reply_user_id
+            reply_text = super().build_reply_content(query, context).strip()
+            if reply_text:
+                self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
+        except Exception as e:
+            logger.exception(e)
+
+    def _do_send_img(self, query, reply_user_id):
+        try:
+            if not query:
+                return
+            context = dict()
+            context['type'] = 'IMAGE_CREATE'
+            img_url = super().build_reply_content(query, context)
+            if not img_url:
+                return
+
+            # 图片下载
+            pic_res = requests.get(img_url, stream=True)
+            image_storage = io.BytesIO()
+            for block in pic_res.iter_content(1024):
+                image_storage.write(block)
+            image_storage.seek(0)
+
+            # 图片发送
+            logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
+            itchat.send_image(image_storage, reply_user_id)
+        except Exception as e:
+            logger.exception(e)
 
     def _do_send_group(self, query, msg):
         if not query:
@@ -98,9 +146,9 @@ class WechatChannel(Channel):
         if reply_text:
             self.send(reply_text, msg['User']['UserName'])
 
+
     def check_prefix(self, content, prefix_list):
         for prefix in prefix_list:
             if content.lower().startswith(prefix):
-                return True
-        return False
-
+                return prefix
+        return None

+ 4 - 3
config.json → config-template.json

@@ -1,7 +1,8 @@
 {
   "open_ai_api_key": "${YOUR API KEY}$",
-  "single_chat_prefix": ["bot", "@bot"],
+  "single_chat_prefix": ["bt", "@bt"],
   "single_chat_reply_prefix": "[bot] ",
-  "group_chat_prefix": ["@bot"],
-  "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"]
+  "group_chat_prefix": ["@bt"],
+  "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],
+  "image_create_prefix": ["画", "看", "找"]
 }

+ 8 - 10
config.py

@@ -10,16 +10,14 @@ config = {}
 def load_config():
     global config
     config_path = "config.json"
-    try:
-        if not os.path.exists(config_path):
-            logger.error('配置文件路径不存在')
-            return
-        config_str = read_file(config_path)
-        # 将json字符串反序列化为dict类型
-        config = json.loads(config_str)
-        logger.info("[INIT] load config: {}".format(config))
-    except Exception as e:
-        logger.error(e)
+    if not os.path.exists(config_path):
+        raise Exception('配置文件不存在,请根据config-template.json模板创建config.json文件')
+
+    config_str = read_file(config_path)
+    # 将json字符串反序列化为dict类型
+    config = json.loads(config_str)
+    logger.info("[INIT] load config: {}".format(config))
+
 
 
 def get_root():

TEMPAT SAMPAH
docs/images/image-create-sample.jpg