Răsfoiți Sursa

Merge pull request #360 from zwssunny/master

修正会话tokens计算
zhayujie 3 ani în urmă
părinte
comite
2886f48788
2 a modificat fișierele cu 33 adăugiri și 28 ștergeri
  1. 1 0
      README.md
  2. 32 28
      bot/chatgpt/chat_gpt_bot.py

+ 1 - 0
README.md

@@ -142,6 +142,7 @@ touch nohup.out                                   # 首次运行需要新建日
 nohup python3 app.py & tail -f nohup.out          # 在后台运行程序并通过日志输出二维码
 ```
 扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。
+scripts/目录有相应的脚本可以调用
 
 > **注意:** 如果 扫码后手机提示登录验证需要等待5s,而终端的二维码再次刷新并提示 `Log in time out, reloading QR code`,此时需参考此 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/8) 修改一行代码即可解决。
 

+ 32 - 28
bot/chatgpt/chat_gpt_bot.py

@@ -6,7 +6,6 @@ from common.log import logger
 from common.expired_dict import ExpiredDict
 import openai
 import time
-import json
 
 if conf().get('expires_in_seconds'):
     user_session = ExpiredDict(conf().get('expires_in_seconds'))
@@ -41,15 +40,22 @@ class ChatGPTBot(Bot):
             #     return self.reply_text_stream(query, new_query, from_user_id)
 
             reply_content = self.reply_text(new_query, from_user_id, 0)
-            logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
-            if reply_content:
-                Session.save_session(query, reply_content, from_user_id)
-            return reply_content
+            logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content["content"]))
+            if reply_content["completion_tokens"] > 0:
+                Session.save_session(reply_content["content"], from_user_id, reply_content["total_tokens"])
+            return reply_content["content"]
 
         elif context.get('type', None) == 'IMAGE_CREATE':
             return self.create_img(query, 0)
 
-    def reply_text(self, query, user_id, retry_count=0):
+    def reply_text(self, query, user_id, retry_count=0) ->dict:
+        '''
+        call openai's ChatCompletion to get the answer
+        :param query: query content
+        :param user_id: from user id
+        :param retry_count: retry count
+        :return: {}
+        '''
         try:
             response = openai.ChatCompletion.create(
                 model="gpt-3.5-turbo",  # 对话模型的名称
@@ -62,8 +68,9 @@ class ChatGPTBot(Bot):
             )
             # res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
             logger.info(response.choices[0]['message']['content'])
-            # log.info("[OPEN_AI] reply={}".format(res_content))
-            return response.choices[0]['message']['content']
+            return {"total_tokens": response["usage"]["total_tokens"], 
+                    "completion_tokens": response["usage"]["completion_tokens"], 
+                    "content": response.choices[0]['message']['content']}
         except openai.error.RateLimitError as e:
             # rate limit exception
             logger.warn(e)
@@ -72,21 +79,21 @@ class ChatGPTBot(Bot):
                 logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
                 return self.reply_text(query, user_id, retry_count+1)
             else:
-                return "提问太快啦,请休息一下再问我吧"
+                return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
         except openai.error.APIConnectionError as e:
             # api connection exception
             logger.warn(e)
             logger.warn("[OPEN_AI] APIConnection failed")
-            return "我连接不到你的网络"
+            return {"completion_tokens": 0, "content":"我连接不到你的网络"}
         except openai.error.Timeout as e:
             logger.warn(e)
             logger.warn("[OPEN_AI] Timeout")
-            return "我没有收到你的消息"
+            return {"completion_tokens": 0, "content":"我没有收到你的消息"}
         except Exception as e:
             # unknown exception
             logger.exception(e)
             Session.clear_session(user_id)
-            return "请再问我一次吧"
+            return {"completion_tokens": 0, "content": "请再问我一次吧"}
 
     def create_img(self, query, retry_count=0):
         try:
@@ -137,11 +144,12 @@ class Session(object):
         return session
 
     @staticmethod
-    def save_session(query, answer, user_id):
+    def save_session(answer, user_id, total_tokens):
         max_tokens = conf().get("conversation_max_tokens")
         if not max_tokens:
             # default 3000
             max_tokens = 1000
+        max_tokens=int(max_tokens)
 
         session = user_session.get(user_id)
         if session:
@@ -150,23 +158,19 @@ class Session(object):
             session.append(gpt_item)
 
         # discard exceed limit conversation
-        Session.discard_exceed_conversation(user_session[user_id], max_tokens) 
+        Session.discard_exceed_conversation(session, max_tokens, total_tokens)
 
     @staticmethod
-    def discard_exceed_conversation(session, max_tokens):
-        count = 0
-        count_list = list()
-        for i in range(len(session)-1, -1, -1):
-        # count tokens of conversation list
-            history_conv = session[i]
-            tokens=json.dumps(history_conv).split()
-            count += len(tokens)
-            count_list.append(count)
-
-        for c in count_list:
-            if c > max_tokens:
-                # pop first conversation
-                session.pop(0)
+    def discard_exceed_conversation(session, max_tokens, total_tokens):
+        dec_tokens=int(total_tokens)
+        # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
+        while dec_tokens > max_tokens:
+            # pop first conversation
+            if len(session) > 0:
+                session.pop(0) 
+            else:
+                break    
+            dec_tokens=dec_tokens-max_tokens
 
     @staticmethod
     def clear_session(user_id):