|
@@ -12,13 +12,15 @@ from common import const
|
|
|
class Bridge(object):
|
|
class Bridge(object):
|
|
|
def __init__(self):
|
|
def __init__(self):
|
|
|
self.btype={
|
|
self.btype={
|
|
|
- "chat": "chatGPT",
|
|
|
|
|
|
|
+ "chat": const.CHATGPT,
|
|
|
"voice_to_text": "openai",
|
|
"voice_to_text": "openai",
|
|
|
"text_to_voice": "baidu"
|
|
"text_to_voice": "baidu"
|
|
|
}
|
|
}
|
|
|
|
|
+ model_type = conf().get("model")
|
|
|
|
|
+ if model_type in ["text-davinci-003"]:
|
|
|
|
|
+ self.btype['chat'] = const.OPEN_AI
|
|
|
self.bots={}
|
|
self.bots={}
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def get_bot(self,typename):
|
|
def get_bot(self,typename):
|
|
|
if self.bots.get(typename) is None:
|
|
if self.bots.get(typename) is None:
|
|
|
logger.info("create bot {} for {}".format(self.btype[typename],typename))
|
|
logger.info("create bot {} for {}".format(self.btype[typename],typename))
|
|
@@ -35,13 +37,7 @@ class Bridge(object):
|
|
|
|
|
|
|
|
|
|
|
|
|
def fetch_reply_content(self, query, context : Context) -> Reply:
|
|
def fetch_reply_content(self, query, context : Context) -> Reply:
|
|
|
- bot_type = const.CHATGPT
|
|
|
|
|
- model_type = conf().get("model")
|
|
|
|
|
- if model_type in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"]:
|
|
|
|
|
- bot_type = const.CHATGPT
|
|
|
|
|
- elif model_type in ["text-davinci-003"]:
|
|
|
|
|
- bot_type = const.OPEN_AI
|
|
|
|
|
- return bot_factory.create_bot(bot_type).reply(query, context)
|
|
|
|
|
|
|
+ return self.get_bot("chat").reply(query, context)
|
|
|
|
|
|
|
|
|
|
|
|
|
def fetch_voice_to_text(self, voiceFile) -> Reply:
|
|
def fetch_voice_to_text(self, voiceFile) -> Reply:
|