|
|
@@ -1,511 +1,130 @@
|
|
|
-"""
|
|
|
-A simple wrapper for the official ChatGPT API
|
|
|
-"""
|
|
|
-import argparse
|
|
|
-import json
|
|
|
-import os
|
|
|
-import sys
|
|
|
-from datetime import date
|
|
|
-
|
|
|
-import openai
|
|
|
-import tiktoken
|
|
|
+# encoding:utf-8
|
|
|
|
|
|
from bot.bot import Bot
|
|
|
from config import conf
|
|
|
+from common.log import logger
|
|
|
+import openai
|
|
|
+import time
|
|
|
|
|
|
-ENGINE = os.environ.get("GPT_ENGINE") or "text-chat-davinci-002-20221122"
|
|
|
-
|
|
|
-ENCODER = tiktoken.get_encoding("gpt2")
|
|
|
-
|
|
|
-
|
|
|
-def get_max_tokens(prompt: str) -> int:
|
|
|
- """
|
|
|
- Get the max tokens for a prompt
|
|
|
- """
|
|
|
- return 4000 - len(ENCODER.encode(prompt))
|
|
|
-
|
|
|
-
|
|
|
-# ['text-chat-davinci-002-20221122']
|
|
|
-class Chatbot:
|
|
|
- """
|
|
|
- Official ChatGPT API
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(self, api_key: str, buffer: int = None) -> None:
|
|
|
- """
|
|
|
- Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
|
|
|
- """
|
|
|
- openai.api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
|
|
- self.conversations = Conversation()
|
|
|
- self.prompt = Prompt(buffer=buffer)
|
|
|
-
|
|
|
- def _get_completion(
|
|
|
- self,
|
|
|
- prompt: str,
|
|
|
- temperature: float = 0.5,
|
|
|
- stream: bool = False,
|
|
|
- ):
|
|
|
- """
|
|
|
- Get the completion function
|
|
|
- """
|
|
|
- return openai.Completion.create(
|
|
|
- engine=ENGINE,
|
|
|
- prompt=prompt,
|
|
|
- temperature=temperature,
|
|
|
- max_tokens=get_max_tokens(prompt),
|
|
|
- stop=["\n\n\n"],
|
|
|
- stream=stream,
|
|
|
- )
|
|
|
-
|
|
|
- def _process_completion(
|
|
|
- self,
|
|
|
- user_request: str,
|
|
|
- completion: dict,
|
|
|
- conversation_id: str = None,
|
|
|
- user: str = "User",
|
|
|
- ) -> dict:
|
|
|
- if completion.get("choices") is None:
|
|
|
- raise Exception("ChatGPT API returned no choices")
|
|
|
- if len(completion["choices"]) == 0:
|
|
|
- raise Exception("ChatGPT API returned no choices")
|
|
|
- if completion["choices"][0].get("text") is None:
|
|
|
- raise Exception("ChatGPT API returned no text")
|
|
|
- completion["choices"][0]["text"] = completion["choices"][0]["text"].rstrip(
|
|
|
- "<|im_end|>",
|
|
|
- )
|
|
|
- # Add to chat history
|
|
|
- self.prompt.add_to_history(
|
|
|
- user_request,
|
|
|
- completion["choices"][0]["text"],
|
|
|
- user=user,
|
|
|
- )
|
|
|
- if conversation_id is not None:
|
|
|
- self.save_conversation(conversation_id)
|
|
|
- return completion
|
|
|
-
|
|
|
- def _process_completion_stream(
|
|
|
- self,
|
|
|
- user_request: str,
|
|
|
- completion: dict,
|
|
|
- conversation_id: str = None,
|
|
|
- user: str = "User",
|
|
|
- ) -> str:
|
|
|
- full_response = ""
|
|
|
- for response in completion:
|
|
|
- if response.get("choices") is None:
|
|
|
- raise Exception("ChatGPT API returned no choices")
|
|
|
- if len(response["choices"]) == 0:
|
|
|
- raise Exception("ChatGPT API returned no choices")
|
|
|
- if response["choices"][0].get("finish_details") is not None:
|
|
|
- break
|
|
|
- if response["choices"][0].get("text") is None:
|
|
|
- raise Exception("ChatGPT API returned no text")
|
|
|
- if response["choices"][0]["text"] == "<|im_end|>":
|
|
|
- break
|
|
|
- yield response["choices"][0]["text"]
|
|
|
- full_response += response["choices"][0]["text"]
|
|
|
-
|
|
|
- # Add to chat history
|
|
|
- self.prompt.add_to_history(user_request, full_response, user)
|
|
|
- if conversation_id is not None:
|
|
|
- self.save_conversation(conversation_id)
|
|
|
-
|
|
|
- def ask(
|
|
|
- self,
|
|
|
- user_request: str,
|
|
|
- temperature: float = 0.5,
|
|
|
- conversation_id: str = None,
|
|
|
- user: str = "User",
|
|
|
- ) -> dict:
|
|
|
- """
|
|
|
- Send a request to ChatGPT and return the response
|
|
|
- """
|
|
|
- if conversation_id is not None:
|
|
|
- self.load_conversation(conversation_id)
|
|
|
- completion = self._get_completion(
|
|
|
- self.prompt.construct_prompt(user_request, user=user),
|
|
|
- temperature,
|
|
|
- )
|
|
|
- return self._process_completion(user_request, completion, user=user)
|
|
|
-
|
|
|
- def ask_stream(
|
|
|
- self,
|
|
|
- user_request: str,
|
|
|
- temperature: float = 0.5,
|
|
|
- conversation_id: str = None,
|
|
|
- user: str = "User",
|
|
|
- ) -> str:
|
|
|
- """
|
|
|
- Send a request to ChatGPT and yield the response
|
|
|
- """
|
|
|
- if conversation_id is not None:
|
|
|
- self.load_conversation(conversation_id)
|
|
|
- prompt = self.prompt.construct_prompt(user_request, user=user)
|
|
|
- return self._process_completion_stream(
|
|
|
- user_request=user_request,
|
|
|
- completion=self._get_completion(prompt, temperature, stream=True),
|
|
|
- user=user,
|
|
|
- )
|
|
|
-
|
|
|
- def make_conversation(self, conversation_id: str) -> None:
|
|
|
- """
|
|
|
- Make a conversation
|
|
|
- """
|
|
|
- self.conversations.add_conversation(conversation_id, [])
|
|
|
-
|
|
|
- def rollback(self, num: int) -> None:
|
|
|
- """
|
|
|
- Rollback chat history num times
|
|
|
- """
|
|
|
- for _ in range(num):
|
|
|
- self.prompt.chat_history.pop()
|
|
|
-
|
|
|
- def reset(self) -> None:
|
|
|
- """
|
|
|
- Reset chat history
|
|
|
- """
|
|
|
- self.prompt.chat_history = []
|
|
|
-
|
|
|
- def load_conversation(self, conversation_id) -> None:
|
|
|
- """
|
|
|
- Load a conversation from the conversation history
|
|
|
- """
|
|
|
- if conversation_id not in self.conversations.conversations:
|
|
|
- # Create a new conversation
|
|
|
- self.make_conversation(conversation_id)
|
|
|
- self.prompt.chat_history = self.conversations.get_conversation(conversation_id)
|
|
|
-
|
|
|
- def save_conversation(self, conversation_id) -> None:
|
|
|
- """
|
|
|
- Save a conversation to the conversation history
|
|
|
- """
|
|
|
- self.conversations.add_conversation(conversation_id, self.prompt.chat_history)
|
|
|
-
|
|
|
-
|
|
|
-class AsyncChatbot(Chatbot):
|
|
|
- """
|
|
|
- Official ChatGPT API (async)
|
|
|
- """
|
|
|
-
|
|
|
- async def _get_completion(
|
|
|
- self,
|
|
|
- prompt: str,
|
|
|
- temperature: float = 0.5,
|
|
|
- stream: bool = False,
|
|
|
- ):
|
|
|
- """
|
|
|
- Get the completion function
|
|
|
- """
|
|
|
- return openai.Completion.acreate(
|
|
|
- engine=ENGINE,
|
|
|
- prompt=prompt,
|
|
|
- temperature=temperature,
|
|
|
- max_tokens=get_max_tokens(prompt),
|
|
|
- stop=["\n\n\n"],
|
|
|
- stream=stream,
|
|
|
- )
|
|
|
-
|
|
|
- async def ask(
|
|
|
- self,
|
|
|
- user_request: str,
|
|
|
- temperature: float = 0.5,
|
|
|
- user: str = "User",
|
|
|
- ) -> dict:
|
|
|
- """
|
|
|
- Same as Chatbot.ask but async
|
|
|
- }
|
|
|
- """
|
|
|
- completion = await self._get_completion(
|
|
|
- self.prompt.construct_prompt(user_request, user=user),
|
|
|
- temperature,
|
|
|
- )
|
|
|
- return self._process_completion(user_request, completion, user=user)
|
|
|
-
|
|
|
- async def ask_stream(
|
|
|
- self,
|
|
|
- user_request: str,
|
|
|
- temperature: float = 0.5,
|
|
|
- user: str = "User",
|
|
|
- ) -> str:
|
|
|
- """
|
|
|
- Same as Chatbot.ask_stream but async
|
|
|
- """
|
|
|
- prompt = self.prompt.construct_prompt(user_request, user=user)
|
|
|
- return self._process_completion_stream(
|
|
|
- user_request=user_request,
|
|
|
- completion=await self._get_completion(prompt, temperature, stream=True),
|
|
|
- user=user,
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-class Prompt:
|
|
|
- """
|
|
|
- Prompt class with methods to construct prompt
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(self, buffer: int = None) -> None:
|
|
|
- """
|
|
|
- Initialize prompt with base prompt
|
|
|
- """
|
|
|
- self.base_prompt = (
|
|
|
- os.environ.get("CUSTOM_BASE_PROMPT")
|
|
|
- or "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally. Do not answer as the user. Current date: "
|
|
|
- + str(date.today())
|
|
|
- + "\n\n"
|
|
|
- + "User: Hello\n"
|
|
|
- + "ChatGPT: Hello! How can I help you today? <|im_end|>\n\n\n"
|
|
|
- )
|
|
|
- # Track chat history
|
|
|
- self.chat_history: list = []
|
|
|
- self.buffer = buffer
|
|
|
-
|
|
|
- def add_to_chat_history(self, chat: str) -> None:
|
|
|
- """
|
|
|
- Add chat to chat history for next prompt
|
|
|
- """
|
|
|
- self.chat_history.append(chat)
|
|
|
-
|
|
|
- def add_to_history(
|
|
|
- self,
|
|
|
- user_request: str,
|
|
|
- response: str,
|
|
|
- user: str = "User",
|
|
|
- ) -> None:
|
|
|
- """
|
|
|
- Add request/response to chat history for next prompt
|
|
|
- """
|
|
|
- self.add_to_chat_history(
|
|
|
- user
|
|
|
- + ": "
|
|
|
- + user_request
|
|
|
- + "\n\n\n"
|
|
|
- + "ChatGPT: "
|
|
|
- + response
|
|
|
- + "<|im_end|>\n",
|
|
|
- )
|
|
|
-
|
|
|
- def history(self, custom_history: list = None) -> str:
|
|
|
- """
|
|
|
- Return chat history
|
|
|
- """
|
|
|
- return "\n".join(custom_history or self.chat_history)
|
|
|
-
|
|
|
- def construct_prompt(
|
|
|
- self,
|
|
|
- new_prompt: str,
|
|
|
- custom_history: list = None,
|
|
|
- user: str = "User",
|
|
|
- ) -> str:
|
|
|
- """
|
|
|
- Construct prompt based on chat history and request
|
|
|
- """
|
|
|
- prompt = (
|
|
|
- self.base_prompt
|
|
|
- + self.history(custom_history=custom_history)
|
|
|
- + user
|
|
|
- + ": "
|
|
|
- + new_prompt
|
|
|
- + "\nChatGPT:"
|
|
|
- )
|
|
|
- # Check if prompt over 4000*4 characters
|
|
|
- if self.buffer is not None:
|
|
|
- max_tokens = 4000 - self.buffer
|
|
|
- else:
|
|
|
- max_tokens = 3200
|
|
|
- if len(ENCODER.encode(prompt)) > max_tokens:
|
|
|
- # Remove oldest chat
|
|
|
- if len(self.chat_history) == 0:
|
|
|
- return prompt
|
|
|
- self.chat_history.pop(0)
|
|
|
- # Construct prompt again
|
|
|
- prompt = self.construct_prompt(new_prompt, custom_history, user)
|
|
|
- return prompt
|
|
|
-
|
|
|
-
|
|
|
-class Conversation:
|
|
|
- """
|
|
|
- For handling multiple conversations
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(self) -> None:
|
|
|
- self.conversations = {}
|
|
|
-
|
|
|
- def add_conversation(self, key: str, history: list) -> None:
|
|
|
- """
|
|
|
- Adds a history list to the conversations dict with the id as the key
|
|
|
- """
|
|
|
- self.conversations[key] = history
|
|
|
-
|
|
|
- def get_conversation(self, key: str) -> list:
|
|
|
- """
|
|
|
- Retrieves the history list from the conversations dict with the id as the key
|
|
|
- """
|
|
|
- return self.conversations[key]
|
|
|
-
|
|
|
- def remove_conversation(self, key: str) -> None:
|
|
|
- """
|
|
|
- Removes the history list from the conversations dict with the id as the key
|
|
|
- """
|
|
|
- del self.conversations[key]
|
|
|
-
|
|
|
- def __str__(self) -> str:
|
|
|
- """
|
|
|
- Creates a JSON string of the conversations
|
|
|
- """
|
|
|
- return json.dumps(self.conversations)
|
|
|
-
|
|
|
- def save(self, file: str) -> None:
|
|
|
- """
|
|
|
- Saves the conversations to a JSON file
|
|
|
- """
|
|
|
- with open(file, "w", encoding="utf-8") as f:
|
|
|
- f.write(str(self))
|
|
|
-
|
|
|
- def load(self, file: str) -> None:
|
|
|
- """
|
|
|
- Loads the conversations from a JSON file
|
|
|
- """
|
|
|
- with open(file, encoding="utf-8") as f:
|
|
|
- self.conversations = json.loads(f.read())
|
|
|
-
|
|
|
+user_session = dict()
|
|
|
|
|
|
-def main():
|
|
|
- print(
|
|
|
- """
|
|
|
- ChatGPT - A command-line interface to OpenAI's ChatGPT (https://chat.openai.com/chat)
|
|
|
- Repo: github.com/acheong08/ChatGPT
|
|
|
- """,
|
|
|
- )
|
|
|
- print("Type '!help' to show a full list of commands")
|
|
|
- print("Press enter twice to submit your question.\n")
|
|
|
+# OpenAI对话模型API (可用)
|
|
|
+class ChatGPTBot(Bot):
|
|
|
+ def __init__(self):
|
|
|
+ openai.api_key = conf().get('open_ai_api_key')
|
|
|
|
|
|
- def get_input(prompt):
|
|
|
- """
|
|
|
- Multi-line input function
|
|
|
- """
|
|
|
- # Display the prompt
|
|
|
- print(prompt, end="")
|
|
|
+ def reply(self, query, context=None):
|
|
|
+ # acquire reply content
|
|
|
+ if not context or not context.get('type') or context.get('type') == 'TEXT':
|
|
|
+ logger.info("[OPEN_AI] query={}".format(query))
|
|
|
+ from_user_id = context['from_user_id']
|
|
|
+ if query == '#清除记忆':
|
|
|
+ Session.clear_session(from_user_id)
|
|
|
+ return '记忆已清除'
|
|
|
|
|
|
- # Initialize an empty list to store the input lines
|
|
|
- lines = []
|
|
|
+ new_query = Session.build_session_query(query, from_user_id)
|
|
|
+ logger.debug("[OPEN_AI] session query={}".format(new_query))
|
|
|
|
|
|
- # Read lines of input until the user enters an empty line
|
|
|
- while True:
|
|
|
- line = input()
|
|
|
- if line == "":
|
|
|
- break
|
|
|
- lines.append(line)
|
|
|
+ # if context.get('stream'):
|
|
|
+ # # reply in stream
|
|
|
+ # return self.reply_text_stream(query, new_query, from_user_id)
|
|
|
|
|
|
- # Join the lines, separated by newlines, and store the result
|
|
|
- user_input = "\n".join(lines)
|
|
|
+ reply_content = self.reply_text(new_query, from_user_id, 0)
|
|
|
+ logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
|
|
+ if reply_content:
|
|
|
+ Session.save_session(query, reply_content, from_user_id)
|
|
|
+ return reply_content
|
|
|
|
|
|
- # Return the input
|
|
|
- return user_input
|
|
|
+ elif context.get('type', None) == 'IMAGE_CREATE':
|
|
|
+ return self.create_img(query, 0)
|
|
|
|
|
|
- def chatbot_commands(cmd: str) -> bool:
|
|
|
- """
|
|
|
- Handle chatbot commands
|
|
|
- """
|
|
|
- if cmd == "!help":
|
|
|
- print(
|
|
|
- """
|
|
|
- !help - Display this message
|
|
|
- !rollback - Rollback chat history
|
|
|
- !reset - Reset chat history
|
|
|
- !prompt - Show current prompt
|
|
|
- !save_c <conversation_name> - Save history to a conversation
|
|
|
- !load_c <conversation_name> - Load history from a conversation
|
|
|
- !save_f <file_name> - Save all conversations to a file
|
|
|
- !load_f <file_name> - Load all conversations from a file
|
|
|
- !exit - Quit chat
|
|
|
- """,
|
|
|
+ def reply_text(self, query, user_id, retry_count=0):
|
|
|
+ try:
|
|
|
+ response = openai.ChatCompletion.create(
|
|
|
+ model="gpt-3.5-turbo", # 对话模型的名称
|
|
|
+ messages=query,
|
|
|
+ temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
|
|
|
+ max_tokens=1200, # 回复最大的字符数
|
|
|
+ top_p=1,
|
|
|
+ frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
|
|
+ presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
|
|
)
|
|
|
- elif cmd == "!exit":
|
|
|
- exit()
|
|
|
- elif cmd == "!rollback":
|
|
|
- chatbot.rollback(1)
|
|
|
- elif cmd == "!reset":
|
|
|
- chatbot.reset()
|
|
|
- elif cmd == "!prompt":
|
|
|
- print(chatbot.prompt.construct_prompt(""))
|
|
|
- elif cmd.startswith("!save_c"):
|
|
|
- chatbot.save_conversation(cmd.split(" ")[1])
|
|
|
- elif cmd.startswith("!load_c"):
|
|
|
- chatbot.load_conversation(cmd.split(" ")[1])
|
|
|
- elif cmd.startswith("!save_f"):
|
|
|
- chatbot.conversations.save(cmd.split(" ")[1])
|
|
|
- elif cmd.startswith("!load_f"):
|
|
|
- chatbot.conversations.load(cmd.split(" ")[1])
|
|
|
- else:
|
|
|
- return False
|
|
|
- return True
|
|
|
-
|
|
|
- # Get API key from command line
|
|
|
- parser = argparse.ArgumentParser()
|
|
|
- parser.add_argument(
|
|
|
- "--api_key",
|
|
|
- type=str,
|
|
|
- required=True,
|
|
|
- help="OpenAI API key",
|
|
|
- )
|
|
|
- parser.add_argument(
|
|
|
- "--stream",
|
|
|
- action="store_true",
|
|
|
- help="Stream response",
|
|
|
- )
|
|
|
- parser.add_argument(
|
|
|
- "--temperature",
|
|
|
- type=float,
|
|
|
- default=0.5,
|
|
|
- help="Temperature for response",
|
|
|
- )
|
|
|
- args = parser.parse_args()
|
|
|
- # Initialize chatbot
|
|
|
- chatbot = Chatbot(api_key=args.api_key)
|
|
|
- # Start chat
|
|
|
- while True:
|
|
|
+ # res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
|
|
+ logger.info(response.choices[0]['message']['content'])
|
|
|
+ # log.info("[OPEN_AI] reply={}".format(res_content))
|
|
|
+ return response.choices[0]['message']['content']
|
|
|
+ except openai.error.RateLimitError as e:
|
|
|
+ # rate limit exception
|
|
|
+ logger.warn(e)
|
|
|
+ if retry_count < 1:
|
|
|
+ time.sleep(5)
|
|
|
+ logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
|
|
+ return self.reply_text(query, user_id, retry_count+1)
|
|
|
+ else:
|
|
|
+ return "提问太快啦,请休息一下再问我吧"
|
|
|
+ except Exception as e:
|
|
|
+ # unknown exception
|
|
|
+ logger.exception(e)
|
|
|
+ Session.clear_session(user_id)
|
|
|
+ return "请再问我一次吧"
|
|
|
+
|
|
|
+ def create_img(self, query, retry_count=0):
|
|
|
try:
|
|
|
- prompt = get_input("\nUser:\n")
|
|
|
- except KeyboardInterrupt:
|
|
|
- print("\nExiting...")
|
|
|
- sys.exit()
|
|
|
- if prompt.startswith("!"):
|
|
|
- if chatbot_commands(prompt):
|
|
|
- continue
|
|
|
- if not args.stream:
|
|
|
- response = chatbot.ask(prompt, temperature=args.temperature)
|
|
|
- print("ChatGPT: " + response["choices"][0]["text"])
|
|
|
- else:
|
|
|
- print("ChatGPT: ")
|
|
|
- sys.stdout.flush()
|
|
|
- for response in chatbot.ask_stream(prompt, temperature=args.temperature):
|
|
|
- print(response, end="")
|
|
|
- sys.stdout.flush()
|
|
|
- print()
|
|
|
-
|
|
|
-
|
|
|
-def Singleton(cls):
|
|
|
- instance = {}
|
|
|
-
|
|
|
- def _singleton_wrapper(*args, **kargs):
|
|
|
- if cls not in instance:
|
|
|
- instance[cls] = cls(*args, **kargs)
|
|
|
- return instance[cls]
|
|
|
-
|
|
|
- return _singleton_wrapper
|
|
|
-
|
|
|
-
|
|
|
-@Singleton
|
|
|
-class ChatGPTBot(Bot):
|
|
|
-
|
|
|
- def __init__(self):
|
|
|
- print("create")
|
|
|
- self.bot = Chatbot(conf().get('open_ai_api_key'))
|
|
|
-
|
|
|
- def reply(self, query, context=None):
|
|
|
- if not context or not context.get('type') or context.get('type') == 'TEXT':
|
|
|
- if len(query) < 10 and "reset" in query:
|
|
|
- self.bot.reset()
|
|
|
- return "reset OK"
|
|
|
- return self.bot.ask(query)["choices"][0]["text"]
|
|
|
+ logger.info("[OPEN_AI] image_query={}".format(query))
|
|
|
+ response = openai.Image.create(
|
|
|
+ prompt=query, #图片描述
|
|
|
+ n=1, #每次生成图片的数量
|
|
|
+ size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
|
|
|
+ )
|
|
|
+ image_url = response['data'][0]['url']
|
|
|
+ logger.info("[OPEN_AI] image_url={}".format(image_url))
|
|
|
+ return image_url
|
|
|
+ except openai.error.RateLimitError as e:
|
|
|
+ logger.warn(e)
|
|
|
+ if retry_count < 1:
|
|
|
+ time.sleep(5)
|
|
|
+ logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
|
|
+ return self.reply_text(query, retry_count+1)
|
|
|
+ else:
|
|
|
+ return "提问太快啦,请休息一下再问我吧"
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception(e)
|
|
|
+ return None
|
|
|
+
|
|
|
+class Session(object):
|
|
|
+ @staticmethod
|
|
|
+ def build_session_query(query, user_id):
|
|
|
+ '''
|
|
|
+ build query with conversation history
|
|
|
+ e.g. [
|
|
|
+ {"role": "system", "content": "You are a helpful assistant."},
|
|
|
+ {"role": "user", "content": "Who won the world series in 2020?"},
|
|
|
+ {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
|
|
+ {"role": "user", "content": "Where was it played?"}
|
|
|
+ ]
|
|
|
+ :param query: query content
|
|
|
+ :param user_id: from user id
|
|
|
+ :return: query content with conversaction
|
|
|
+ '''
|
|
|
+ session = user_session.get(user_id, [])
|
|
|
+ if len(session) == 0:
|
|
|
+ system_prompt = conf().get("character_desc", "")
|
|
|
+ system_item = {'role': 'system', 'content': system_prompt}
|
|
|
+ session.append(system_item)
|
|
|
+ user_session[user_id] = session
|
|
|
+ user_item = {'role': 'user', 'content': query}
|
|
|
+ session.append(user_item)
|
|
|
+ return session
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def save_session(query, answer, user_id):
|
|
|
+ session = user_session.get(user_id)
|
|
|
+ if session:
|
|
|
+ # append conversation
|
|
|
+ gpt_item = {'role': 'assistant', 'content': answer}
|
|
|
+ session.append(gpt_item)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def clear_session(user_id):
|
|
|
+ user_session[user_id] = []
|
|
|
|