|
|
@@ -13,6 +13,9 @@ from bridge.reply import Reply, ReplyType
|
|
|
from common.log import logger
|
|
|
from config import conf, pconf
|
|
|
import threading
|
|
|
+from common import memory, utils
|
|
|
+import base64
|
|
|
+
|
|
|
|
|
|
class LinkAIBot(Bot):
|
|
|
# authentication failed
|
|
|
@@ -21,7 +24,7 @@ class LinkAIBot(Bot):
|
|
|
|
|
|
def __init__(self):
|
|
|
super().__init__()
|
|
|
- self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
|
|
+ self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
|
|
|
self.args = {}
|
|
|
|
|
|
def reply(self, query, context: Context = None) -> Reply:
|
|
|
@@ -61,17 +64,25 @@ class LinkAIBot(Bot):
|
|
|
linkai_api_key = conf().get("linkai_api_key")
|
|
|
|
|
|
session_id = context["session_id"]
|
|
|
+ session_message = self.sessions.session_msg_query(query, session_id)
|
|
|
+ logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}")
|
|
|
+
|
|
|
+ # image process
|
|
|
+ img_cache = memory.USER_IMAGE_CACHE.get(session_id)
|
|
|
+ if img_cache:
|
|
|
+ messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache)
|
|
|
+ if messages:
|
|
|
+ session_message = messages
|
|
|
|
|
|
- session = self.sessions.session_query(query, session_id)
|
|
|
model = conf().get("model")
|
|
|
# remove system message
|
|
|
- if session.messages[0].get("role") == "system":
|
|
|
+ if session_message[0].get("role") == "system":
|
|
|
if app_code or model == "wenxin":
|
|
|
- session.messages.pop(0)
|
|
|
+ session_message.pop(0)
|
|
|
|
|
|
body = {
|
|
|
"app_code": app_code,
|
|
|
- "messages": session.messages,
|
|
|
+ "messages": session_message,
|
|
|
"model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
|
|
"temperature": conf().get("temperature"),
|
|
|
"top_p": conf().get("top_p", 1),
|
|
|
@@ -94,7 +105,7 @@ class LinkAIBot(Bot):
|
|
|
reply_content = response["choices"][0]["message"]["content"]
|
|
|
total_tokens = response["usage"]["total_tokens"]
|
|
|
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
|
|
|
- self.sessions.session_reply(reply_content, session_id, total_tokens)
|
|
|
+ self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
|
|
|
|
|
|
agent_suffix = self._fetch_agent_suffix(response)
|
|
|
if agent_suffix:
|
|
|
@@ -130,6 +141,54 @@ class LinkAIBot(Bot):
|
|
|
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
|
|
return self._chat(query, context, retry_count + 1)
|
|
|
|
|
|
def _process_image_msg(self, app_code: str, session_id: str, query: str, img_cache: dict):
    """Build vision-style messages when the current app accepts image input.

    Fetches the app's info and checks whether any of its plugins declares
    "IMAGE" in its input_type; if so, prepares the cached image file and
    builds a multimodal message list from the query plus that image.

    :param app_code: LinkAI app code used to look up plugin capabilities.
    :param session_id: session key; the image-cache entry is consumed on use.
    :param query: the user's text query.
    :param img_cache: dict holding "msg" (a channel message object exposing
        prepare()) and "path" (local image file path) — presumably populated
        by the image-receive hook; verify against the caller.
    :return: message list for a vision request, or None when the app is
        unknown, does not accept images, or any error occurs (best-effort).
    """
    try:
        app_info = self._fetch_app_info(app_code)
        if not app_info:
            logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}")
            return None
        # Guard the chain: a missing "data" or "plugins" field previously
        # raised and was swallowed by the broad except below.
        plugins = (app_info.get("data") or {}).get("plugins") or []
        # Image input is enabled iff at least one plugin declares IMAGE.
        enable_image_input = any(
            plugin.get("input_type") and "IMAGE" in plugin.get("input_type")
            for plugin in plugins
        )
        if not enable_image_input:
            return
        msg = img_cache.get("msg")
        path = img_cache.get("path")
        msg.prepare()  # materialize/download the image file if needed
        logger.info(f"[LinkAI] query with images, path={path}")
        messages = self._build_vision_msg(query, path)
        # Consume the cache entry so the image is only attached once.
        memory.USER_IMAGE_CACHE[session_id] = None
        return messages
    except Exception as e:
        # Best-effort: on any failure fall back to a plain text query.
        logger.exception(e)
|
|
|
+
|
|
|
+
|
|
|
def _build_vision_msg(self, query: str, path: str):
    """Assemble an OpenAI-vision-style message list from a text query and
    a local image file, inlining the image as a base64 data URL.

    :param query: the user's text prompt.
    :param path: local filesystem path of the image.
    :return: a single-element message list, or None on any error.
    """
    try:
        img_suffix = utils.get_path_suffix(path)
        with open(path, "rb") as img_file:
            raw_bytes = img_file.read()
        encoded = base64.b64encode(raw_bytes).decode('utf-8')
        data_url = f"data:image/{img_suffix};base64,{encoded}"
        content = [
            {"type": "text", "text": query},
            {"type": "image_url", "image_url": {"url": data_url}},
        ]
        return [{"role": "user", "content": content}]
    except Exception as e:
        logger.exception(e)
|
|
|
+
|
|
|
def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
|
|
|
if retry_count >= 2:
|
|
|
# exit from retry 2 times
|
|
|
@@ -195,6 +254,16 @@ class LinkAIBot(Bot):
|
|
|
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
|
|
return self.reply_text(session, app_code, retry_count + 1)
|
|
|
|
|
|
def _fetch_app_info(self, app_code: str):
    """Fetch app metadata (including its plugin list) from the LinkAI API.

    :param app_code: code identifying the LinkAI app.
    :return: parsed JSON dict on HTTP 200, otherwise None (after logging).
    """
    base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
    headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
    res = requests.get(
        url=base_url + "/v1/app/info",
        params={"app_code": app_code},
        headers=headers,
        timeout=(5, 10),  # (connect, read) seconds
    )
    if res.status_code != 200:
        logger.warning(f"[LinkAI] find app info exception, res={res}")
        return None
    return res.json()
|
|
|
|
|
|
def create_img(self, query, retry_count=0, api_key=None):
|
|
|
try:
|
|
|
@@ -239,6 +308,7 @@ class LinkAIBot(Bot):
|
|
|
except Exception as e:
|
|
|
logger.exception(e)
|
|
|
|
|
|
+
|
|
|
def _fetch_agent_suffix(self, response):
|
|
|
try:
|
|
|
plugin_list = []
|
|
|
@@ -275,4 +345,44 @@ class LinkAIBot(Bot):
|
|
|
reply = Reply(ReplyType.IMAGE_URL, url)
|
|
|
channel.send(reply, context)
|
|
|
except Exception as e:
|
|
|
- logger.error(e)
|
|
|
+ logger.error(e)
|
|
|
+
|
|
|
+
|
|
|
class LinkAISessionManager(SessionManager):
    """Session manager with deferred commit: a query can be previewed via
    session_msg_query without being recorded, and is only written to the
    session (together with the reply) in session_reply."""

    def session_msg_query(self, query, session_id):
        """Return the message list a request would send, WITHOUT mutating
        the stored session history.

        :param query: the user's text query.
        :param session_id: key identifying the conversation.
        :return: stored history plus the pending user message.
        """
        session = self.build_session(session_id)
        return session.messages + [{"role": "user", "content": query}]

    def session_reply(self, reply, session_id, total_tokens=None, query=None):
        """Commit a successful exchange and trim history to the token budget.

        :param reply: assistant reply text to append.
        :param session_id: key identifying the conversation.
        :param total_tokens: token usage reported by the API (logged only).
        :param query: the user query to record first, if provided.
        :return: the updated session.
        """
        session = self.build_session(session_id)
        if query:
            session.add_query(query)
        session.add_reply(reply)
        try:
            max_tokens = conf().get("conversation_max_tokens", 2500)
            tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
            logger.info(f"[LinkAI] chat history discard, before tokens={total_tokens}, now tokens={tokens_cnt}")
        except Exception as e:
            # Trimming is best-effort; a counting failure must not lose the reply.
            logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
        return session
|
|
|
+
|
|
|
+
|
|
|
class LinkAISession(ChatGPTSession):
    """ChatGPT-style session whose token counting degrades gracefully:
    when precise counting raises, a crude character-count estimate is used."""

    def calc_tokens(self):
        """Return the token count of the current history.

        Falls back to len(str(self.messages)) — a rough character-based
        estimate — when the precise tokenizer raises.
        """
        try:
            cur_tokens = super().calc_tokens()
        except Exception as e:
            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
            cur_tokens = len(str(self.messages))
        return cur_tokens

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        """Drop oldest user/assistant pairs until history fits the budget.

        :param max_tokens: token budget for the whole history.
        :param cur_tokens: ignored; tokens are always recomputed so the
            fallback estimate in calc_tokens stays consistent.
        :return: token count after trimming (may still exceed max_tokens
            if no removable pair remains, e.g. only a system prompt left).
        """
        cur_tokens = self.calc_tokens()
        # BUGFIX: the previous implementation removed at most ONE pair per
        # call, which could leave the history above the budget; keep
        # removing the earliest (user, assistant) pair until it fits.
        while cur_tokens > max_tokens:
            removed = False
            for i in range(1, len(self.messages)):
                if self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user":
                    # Pop the later index first so the earlier one stays valid.
                    self.messages.pop(i)
                    self.messages.pop(i - 1)
                    removed = True
                    break
            if not removed:
                break  # nothing removable left; avoid an infinite loop
            cur_tokens = self.calc_tokens()
        return cur_tokens
|