|
|
@@ -9,6 +9,7 @@ import anthropic
|
|
|
from bot.bot import Bot
|
|
|
from bot.openai.open_ai_image import OpenAIImage
|
|
|
from bot.claudeapi.claude_api_session import ClaudeAPISession
|
|
|
+from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
|
|
from bot.session_manager import SessionManager
|
|
|
from bridge.context import ContextType
|
|
|
from bridge.reply import Reply, ReplyType
|
|
|
@@ -32,7 +33,7 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
|
|
if proxy:
|
|
|
openai.proxy = proxy
|
|
|
|
|
|
- self.sessions = SessionManager(ClaudeAPISession, model=conf().get("model") or "text-davinci-003")
|
|
|
+ self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "text-davinci-003")
|
|
|
|
|
|
def reply(self, query, context=None):
|
|
|
# acquire reply content
|
|
|
@@ -75,16 +76,17 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
|
|
reply = Reply(ReplyType.ERROR, retstring)
|
|
|
return reply
|
|
|
|
|
|
- def reply_text(self, session: ClaudeAPISession, retry_count=0):
|
|
|
+ def reply_text(self, session: ChatGPTSession, retry_count=0):
|
|
|
try:
|
|
|
- logger.info("[CLAUDE_API] sendMessage={}".format(str(session)))
|
|
|
+ if session.messages[0].get("role") == "system":
|
|
|
+ system = session.messages[0].get("content")
|
|
|
+ session.messages.pop(0)
|
|
|
+ actual_model = self._model_mapping(conf().get("model"))
|
|
|
response = self.claudeClient.messages.create(
|
|
|
- model=conf().get("model"),
|
|
|
+ model=actual_model,
|
|
|
max_tokens=1024,
|
|
|
# system=conf().get("system"),
|
|
|
- messages=[
|
|
|
- {"role": "user", "content": "{}".format(str(session))}
|
|
|
- ]
|
|
|
+ messages=session.messages
|
|
|
)
|
|
|
# response = openai.Completion.create(prompt=str(session), **self.args)
|
|
|
res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
|
|
|
@@ -123,3 +125,12 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
|
|
return self.reply_text(session, retry_count + 1)
|
|
|
else:
|
|
|
return result
|
|
|
+
|
|
|
+ def _model_mapping(self, model) -> str:
|
|
|
+ if model == "claude-3-opus":
|
|
|
+ return "claude-3-opus-20240229"
|
|
|
+ elif model == "claude-3-sonnet":
|
|
|
+ return "claude-3-sonnet-20240229"
|
|
|
+ elif model == "claude-3-haiku":
|
|
|
+ return "claude-3-haiku-20240307"
|
|
|
+ return model
|