Selaa lähdekoodia

feat: image input and session optimize

zhayujie 2 vuotta sitten
vanhempi
säilyke
4e675b84fb
6 muutettua tiedostoa jossa 134 lisäystä ja 16 poistoa
  1. 117 7
      bot/linkai/link_ai_bot.py
  2. 2 2
      bot/session_manager.py
  3. 6 5
      channel/chat_channel.py
  4. 3 0
      common/memory.py
  5. 6 1
      common/utils.py
  6. 0 1
      plugins/linkai/summary.py

+ 117 - 7
bot/linkai/link_ai_bot.py

@@ -13,6 +13,9 @@ from bridge.reply import Reply, ReplyType
 from common.log import logger
 from config import conf, pconf
 import threading
+from common import memory, utils
+import base64
+
 
 class LinkAIBot(Bot):
     # authentication failed
@@ -21,7 +24,7 @@ class LinkAIBot(Bot):
 
     def __init__(self):
         super().__init__()
-        self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
+        self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
         self.args = {}
 
     def reply(self, query, context: Context = None) -> Reply:
@@ -61,17 +64,25 @@ class LinkAIBot(Bot):
             linkai_api_key = conf().get("linkai_api_key")
 
             session_id = context["session_id"]
+            session_message = self.sessions.session_msg_query(query, session_id)
+            logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}")
+
+            # image process
+            img_cache = memory.USER_IMAGE_CACHE.get(session_id)
+            if img_cache:
+                messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache)
+                if messages:
+                    session_message = messages
 
-            session = self.sessions.session_query(query, session_id)
             model = conf().get("model")
             # remove system message
-            if session.messages[0].get("role") == "system":
+            if session_message[0].get("role") == "system":
                 if app_code or model == "wenxin":
-                    session.messages.pop(0)
+                    session_message.pop(0)
 
             body = {
                 "app_code": app_code,
-                "messages": session.messages,
+                "messages": session_message,
                 "model": model,     # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
                 "temperature": conf().get("temperature"),
                 "top_p": conf().get("top_p", 1),
@@ -94,7 +105,7 @@ class LinkAIBot(Bot):
                 reply_content = response["choices"][0]["message"]["content"]
                 total_tokens = response["usage"]["total_tokens"]
                 logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
-                self.sessions.session_reply(reply_content, session_id, total_tokens)
+                self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
     
                 agent_suffix = self._fetch_agent_suffix(response)
                 if agent_suffix:
@@ -130,6 +141,54 @@ class LinkAIBot(Bot):
             logger.warn(f"[LINKAI] do retry, times={retry_count}")
             return self._chat(query, context, retry_count + 1)
 
+    def _process_image_msg(self, app_code: str, session_id: str, query:str, img_cache: dict):
+        try:
+            enable_image_input = False
+            app_info = self._fetch_app_info(app_code)
+            if not app_info:
+                logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}")
+                return None
+            plugins = app_info.get("data").get("plugins")
+            for plugin in plugins:
+                if plugin.get("input_type") and "IMAGE" in plugin.get("input_type"):
+                    enable_image_input = True
+            if not enable_image_input:
+                return
+            msg = img_cache.get("msg")
+            path = img_cache.get("path")
+            msg.prepare()
+            logger.info(f"[LinkAI] query with images, path={path}")
+            messages = self._build_vision_msg(query, path)
+            memory.USER_IMAGE_CACHE[session_id] = None
+            return messages
+        except Exception as e:
+            logger.exception(e)
+
+
+    def _build_vision_msg(self, query: str, path: str):
+        try:
+            suffix = utils.get_path_suffix(path)
+            with open(path, "rb") as file:
+                base64_str = base64.b64encode(file.read()).decode('utf-8')
+                messages = [{
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": query
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/{suffix};base64,{base64_str}"
+                            }
+                        }
+                    ]
+                }]
+                return messages
+        except Exception as e:
+            logger.exception(e)
+
     def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
         if retry_count >= 2:
             # exit from retry 2 times
@@ -195,6 +254,16 @@ class LinkAIBot(Bot):
             logger.warn(f"[LINKAI] do retry, times={retry_count}")
             return self.reply_text(session, app_code, retry_count + 1)
 
+    def _fetch_app_info(self, app_code: str):
+        headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
+        # do http request
+        base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
+        params = {"app_code": app_code}
+        res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
+        if res.status_code == 200:
+            return res.json()
+        else:
+            logger.warning(f"[LinkAI] find app info exception, res={res}")
 
     def create_img(self, query, retry_count=0, api_key=None):
         try:
@@ -239,6 +308,7 @@ class LinkAIBot(Bot):
         except Exception as e:
             logger.exception(e)
 
+
     def _fetch_agent_suffix(self, response):
         try:
             plugin_list = []
@@ -275,4 +345,44 @@ class LinkAIBot(Bot):
                 reply = Reply(ReplyType.IMAGE_URL, url)
                 channel.send(reply, context)
         except Exception as e:
-            logger.error(e)
+            logger.error(e)
+
+
+class LinkAISessionManager(SessionManager):
+    def session_msg_query(self, query, session_id):
+        session = self.build_session(session_id)
+        messages = session.messages + [{"role": "user", "content": query}]
+        return messages
+
+    def session_reply(self, reply, session_id, total_tokens=None, query=None):
+        session = self.build_session(session_id)
+        if query:
+            session.add_query(query)
+        session.add_reply(reply)
+        try:
+            max_tokens = conf().get("conversation_max_tokens", 2500)
+            tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
+            logger.info(f"[LinkAI] chat history discard, before tokens={total_tokens}, now tokens={tokens_cnt}")
+        except Exception as e:
+            logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
+        return session
+
+
+class LinkAISession(ChatGPTSession):
+    def calc_tokens(self):
+        try:
+            cur_tokens = super().calc_tokens()
+        except Exception as e:
+            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+            cur_tokens = len(str(self.messages))
+        return cur_tokens
+
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
+        cur_tokens = self.calc_tokens()
+        if cur_tokens > max_tokens:
+            for i in range(0, len(self.messages)):
+                if i > 0 and self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user":
+                    self.messages.pop(i)
+                    self.messages.pop(i - 1)
+                    return self.calc_tokens()
+        return cur_tokens

+ 2 - 2
bot/session_manager.py

@@ -69,7 +69,7 @@ class SessionManager(object):
             total_tokens = session.discard_exceeding(max_tokens, None)
             logger.debug("prompt tokens used={}".format(total_tokens))
         except Exception as e:
-            logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
+            logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e)))
         return session
 
     def session_reply(self, reply, session_id, total_tokens=None):
@@ -80,7 +80,7 @@ class SessionManager(object):
             tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
             logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
         except Exception as e:
-            logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
+            logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
         return session
 
     def clear_session(self, session_id):

+ 6 - 5
channel/chat_channel.py

@@ -9,8 +9,7 @@ from bridge.context import *
 from bridge.reply import *
 from channel.channel import Channel
 from common.dequeue import Dequeue
-from common.log import logger
-from config import conf
+from common import memory
 from plugins import *
 
 try:
@@ -205,14 +204,16 @@ class ChatChannel(Channel):
                     else:
                         return
             elif context.type == ContextType.IMAGE:  # 图片消息,当前仅做下载保存到本地的逻辑
-                cmsg = context["msg"]
-                cmsg.prepare()
+                memory.USER_IMAGE_CACHE[context["session_id"]] = {
+                    "path": context.content,
+                    "msg": context.get("msg")
+                }
             elif context.type == ContextType.SHARING:  # 分享信息,当前无默认逻辑
                 pass
             elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE:  # 文件消息及函数调用等,当前无默认逻辑
                 pass
             else:
-                logger.error("[WX] unknown context type: {}".format(context.type))
+                logger.warning("[WX] unknown context type: {}".format(context.type))
                 return
         return reply
 

+ 3 - 0
common/memory.py

@@ -0,0 +1,3 @@
+from common.expired_dict import ExpiredDict
+
+USER_IMAGE_CACHE = ExpiredDict(60 * 3)

+ 6 - 1
common/utils.py

@@ -1,6 +1,6 @@
 import io
 import os
-
+from urllib.parse import urlparse
 from PIL import Image
 
 
@@ -49,3 +49,8 @@ def split_string_by_utf8_length(string, max_length, max_split=0):
         result.append(encoded[start:end].decode("utf-8"))
         start = end
     return result
+
+
+def get_path_suffix(path):
+    path = urlparse(path).path
+    return os.path.splitext(path)[-1].lstrip('.')

+ 0 - 1
plugins/linkai/summary.py

@@ -91,5 +91,4 @@ class LinkSummary:
         for support_url in support_list:
             if url.strip().startswith(support_url):
                 return True
-        logger.debug(f"[LinkSum] unsupported url, no need to process, url={url}")
         return False