Sfoglia il codice sorgente

chore: add calc_tokens method on session

lanvent 3 anni fa
parent
commit
2989249e4b

+ 4 - 4
bot/chatgpt/chat_gpt_bot.py

@@ -58,7 +58,7 @@ class ChatGPTBot(Bot,OpenAIImage):
             #     # reply in stream
             #     # reply in stream
             #     return self.reply_text_stream(query, new_query, session_id)
             #     return self.reply_text_stream(query, new_query, session_id)
 
 
-            reply_content = self.reply_text(session, session_id, api_key, 0)
+            reply_content = self.reply_text(session, api_key)
             logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
             logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
             if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
             if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
                 reply = Reply(ReplyType.ERROR, reply_content['content'])
                 reply = Reply(ReplyType.ERROR, reply_content['content'])
@@ -94,7 +94,7 @@ class ChatGPTBot(Bot,OpenAIImage):
             "timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试
             "timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试
         }
         }
 
 
-    def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict:
+    def reply_text(self, session:ChatGPTSession, api_key=None, retry_count=0) -> dict:
         '''
         '''
         call openai's ChatCompletion to get the answer
         call openai's ChatCompletion to get the answer
         :param session: a conversation session
         :param session: a conversation session
@@ -133,11 +133,11 @@ class ChatGPTBot(Bot,OpenAIImage):
             else:
             else:
                 logger.warn("[CHATGPT] Exception: {}".format(e))
                 logger.warn("[CHATGPT] Exception: {}".format(e))
                 need_retry = False
                 need_retry = False
-                self.sessions.clear_session(session_id)
+                self.sessions.clear_session(session.session_id)
 
 
             if need_retry:
             if need_retry:
                 logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
                 logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
-                return self.reply_text(session, session_id, api_key, retry_count+1)
+                return self.reply_text(session, api_key, retry_count+1)
             else:
             else:
                 return result
                 return result
 
 

+ 6 - 3
bot/chatgpt/chat_gpt_session.py

@@ -17,7 +17,7 @@ class ChatGPTSession(Session):
     def discard_exceeding(self, max_tokens, cur_tokens= None):
     def discard_exceeding(self, max_tokens, cur_tokens= None):
         precise = True
         precise = True
         try:
         try:
-            cur_tokens = num_tokens_from_messages(self.messages, self.model)
+            cur_tokens = self.calc_tokens()
         except Exception as e:
         except Exception as e:
             precise = False
             precise = False
             if cur_tokens is None:
             if cur_tokens is None:
@@ -29,7 +29,7 @@ class ChatGPTSession(Session):
             elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
             elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
                 self.messages.pop(1)
                 self.messages.pop(1)
                 if precise:
                 if precise:
-                    cur_tokens = num_tokens_from_messages(self.messages, self.model)
+                    cur_tokens = self.calc_tokens()
                 else:
                 else:
                     cur_tokens = cur_tokens - max_tokens
                     cur_tokens = cur_tokens - max_tokens
                 break
                 break
@@ -40,11 +40,14 @@ class ChatGPTSession(Session):
                 logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
                 logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
                 break
                 break
             if precise:
             if precise:
-                cur_tokens = num_tokens_from_messages(self.messages, self.model)
+                cur_tokens = self.calc_tokens()
             else:
             else:
                 cur_tokens = cur_tokens - max_tokens
                 cur_tokens = cur_tokens - max_tokens
         return cur_tokens
         return cur_tokens
     
     
+    def calc_tokens(self):
+        return num_tokens_from_messages(self.messages, self.model)
+    
 
 
 # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 def num_tokens_from_messages(messages, model):
 def num_tokens_from_messages(messages, model):

+ 14 - 14
bot/openai/open_ai_bot.py

@@ -42,11 +42,9 @@ class OpenAIBot(Bot, OpenAIImage):
                     reply = Reply(ReplyType.INFO, '所有人记忆已清除')
                     reply = Reply(ReplyType.INFO, '所有人记忆已清除')
                 else:
                 else:
                     session = self.sessions.session_query(query, session_id)
                     session = self.sessions.session_query(query, session_id)
-                    new_query = str(session)
-                    logger.debug("[OPEN_AI] session query={}".format(new_query))
-
-                    total_tokens, completion_tokens, reply_content = self.reply_text(new_query, session_id, 0)
-                    logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(new_query, session_id, reply_content, completion_tokens))
+                    result = self.reply_text(session)
+                    total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content']
+                    logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens))
 
 
                     if total_tokens == 0 :
                     if total_tokens == 0 :
                         reply = Reply(ReplyType.ERROR, reply_content)
                         reply = Reply(ReplyType.ERROR, reply_content)
@@ -63,11 +61,11 @@ class OpenAIBot(Bot, OpenAIImage):
                     reply = Reply(ReplyType.ERROR, retstring)
                     reply = Reply(ReplyType.ERROR, retstring)
                 return reply
                 return reply
 
 
-    def reply_text(self, query, session_id, retry_count=0):
+    def reply_text(self, session:OpenAISession, retry_count=0):
         try:
         try:
             response = openai.Completion.create(
             response = openai.Completion.create(
                 model= conf().get("model") or "text-davinci-003",  # 对话模型的名称
                 model= conf().get("model") or "text-davinci-003",  # 对话模型的名称
-                prompt=query,
+                prompt=str(session),
                 temperature=0.9,  # 值在[0,1]之间,越大表示回复越具有不确定性
                 temperature=0.9,  # 值在[0,1]之间,越大表示回复越具有不确定性
                 max_tokens=1200,  # 回复最大的字符数
                 max_tokens=1200,  # 回复最大的字符数
                 top_p=1,
                 top_p=1,
@@ -79,31 +77,33 @@ class OpenAIBot(Bot, OpenAIImage):
             total_tokens = response["usage"]["total_tokens"]
             total_tokens = response["usage"]["total_tokens"]
             completion_tokens = response["usage"]["completion_tokens"]
             completion_tokens = response["usage"]["completion_tokens"]
             logger.info("[OPEN_AI] reply={}".format(res_content))
             logger.info("[OPEN_AI] reply={}".format(res_content))
-            return total_tokens, completion_tokens, res_content
+            return {"total_tokens": total_tokens,
+                    "completion_tokens": completion_tokens,
+                    "content": res_content}
         except Exception as e:
         except Exception as e:
             need_retry = retry_count < 2
             need_retry = retry_count < 2
-            result = [0,0,"我现在有点累了,等会再来吧"]
+            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
             if isinstance(e, openai.error.RateLimitError):
             if isinstance(e, openai.error.RateLimitError):
                 logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
                 logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
-                result[2] = "提问太快啦,请休息一下再问我吧"
+                result['content'] = "提问太快啦,请休息一下再问我吧"
                 if need_retry:
                 if need_retry:
                     time.sleep(5)
                     time.sleep(5)
             elif isinstance(e, openai.error.Timeout):
             elif isinstance(e, openai.error.Timeout):
                 logger.warn("[OPEN_AI] Timeout: {}".format(e))
                 logger.warn("[OPEN_AI] Timeout: {}".format(e))
-                result[2] = "我没有收到你的消息"
+                result['content'] = "我没有收到你的消息"
                 if need_retry:
                 if need_retry:
                     time.sleep(5)
                     time.sleep(5)
             elif isinstance(e, openai.error.APIConnectionError):
             elif isinstance(e, openai.error.APIConnectionError):
                 logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
                 logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
                 need_retry = False
                 need_retry = False
-                result[2] = "我连接不到你的网络"
+                result['content'] = "我连接不到你的网络"
             else:
             else:
                 logger.warn("[OPEN_AI] Exception: {}".format(e))
                 logger.warn("[OPEN_AI] Exception: {}".format(e))
                 need_retry = False
                 need_retry = False
-                self.sessions.clear_session(session_id)
+                self.sessions.clear_session(session.session_id)
 
 
             if need_retry:
             if need_retry:
                 logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1))
                 logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1))
-                return self.reply_text(query, session_id, retry_count+1)
+                return self.reply_text(session, retry_count+1)
             else:
             else:
                 return result
                 return result

+ 5 - 3
bot/openai/open_ai_session.py

@@ -29,7 +29,7 @@ class OpenAISession(Session):
     def discard_exceeding(self, max_tokens, cur_tokens= None):
     def discard_exceeding(self, max_tokens, cur_tokens= None):
         precise = True
         precise = True
         try:
         try:
-            cur_tokens = num_tokens_from_string(str(self), self.model)
+            cur_tokens = self.calc_tokens()
         except Exception as e:
         except Exception as e:
             precise = False
             precise = False
             if cur_tokens is None:
             if cur_tokens is None:
@@ -41,7 +41,7 @@ class OpenAISession(Session):
             elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
             elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
                 self.messages.pop(0)
                 self.messages.pop(0)
                 if precise:
                 if precise:
-                    cur_tokens = num_tokens_from_string(str(self), self.model)
+                    cur_tokens = self.calc_tokens()
                 else:
                 else:
                     cur_tokens = len(str(self))
                     cur_tokens = len(str(self))
                 break
                 break
@@ -52,11 +52,13 @@ class OpenAISession(Session):
                 logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
                 logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
                 break
                 break
             if precise:
             if precise:
-                cur_tokens = num_tokens_from_string(str(self), self.model)
+                cur_tokens = self.calc_tokens()
             else:
             else:
                 cur_tokens = len(str(self))
                 cur_tokens = len(str(self))
         return cur_tokens
         return cur_tokens
     
     
+    def calc_tokens(self):
+        return num_tokens_from_string(str(self), self.model)
 
 
 # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 def num_tokens_from_string(string: str, model: str) -> int:
 def num_tokens_from_string(string: str, model: str) -> int:

+ 2 - 0
bot/session_manager.py

@@ -31,6 +31,8 @@ class Session(object):
     def discard_exceeding(self, max_tokens=None, cur_tokens=None):
     def discard_exceeding(self, max_tokens=None, cur_tokens=None):
         raise NotImplementedError
         raise NotImplementedError
 
 
+    def calc_tokens(self):
+        raise NotImplementedError
 
 
 
 
 class SessionManager(object):
 class SessionManager(object):