You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

95 lines
3.5 KiB

1 year ago
import openai
from openai import OpenAI
import tiktoken
from config import Config
from colorama import Fore, Style
import time
class AI_assistant:
    """Thin wrapper around the OpenAI Assistants (beta) API.

    Holds one assistant, one thread and the most recent run, and offers
    helpers to create them, poll a run, read the thread's messages together
    with a rough cost estimate, and manage uploaded files.
    """

    # USD per 1K tokens used by process_messages' cost estimate (user text is
    # priced as input, assistant text as output).
    # NOTE(review): presumably the gpt-4-1106-preview rates, matching the
    # tokenizer pinned in __init__ — confirm and update both together.
    _INPUT_PRICE_PER_1K = 0.01
    _OUTPUT_PRICE_PER_1K = 0.03

    def __init__(self, cfg: Config):
        """Configure the OpenAI client and model settings from *cfg*.

        Args:
            cfg: Project configuration supplying ``open_ai_key``,
                ``open_ai_chat_model``, ``use_stream``, ``language`` and
                ``temperature``.
        """
        # Global key presumably kept for code elsewhere that relies on the
        # module-level `openai` configuration; this class uses self.client.
        openai.api_key = cfg.open_ai_key
        # openai.proxy = cfg.open_ai_proxy
        self._chat_model = cfg.open_ai_chat_model
        # _use_stream, _language and _temperature are stored but not read by
        # any method of this class.
        self._use_stream = cfg.use_stream
        # Token counting (only used for the cost estimate) is pinned to this
        # model's tokenizer regardless of self._chat_model.
        self._encoding = tiktoken.encoding_for_model('gpt-4-1106-preview')
        self._language = cfg.language
        self._temperature = cfg.temperature
        self.client = OpenAI(api_key=cfg.open_ai_key)
        # Set by create_assistant / create_thread / run_assistant respectively.
        self.assistant = None
        self.thread = None
        self.run = None

    def check_run(self, thread_id, run_id, poll_interval=3):
        """Poll a run until it reaches a terminal status.

        Args:
            thread_id: ID of the thread the run belongs to.
            run_id: ID of the run to poll.
            poll_interval: Seconds to sleep between status checks.

        Returns:
            The last retrieved run object; its ``status`` tells the caller
            how the run ended.
        """
        while True:
            # Refresh the run object to get the latest status.
            run = self.client.beta.threads.runs.retrieve(
                thread_id=thread_id,
                run_id=run_id
            )
            if run.status == "completed":
                print(f"{Fore.GREEN} Run is completed.{Style.RESET_ALL}")
                return run
            if run.status == "expired":
                print(f"{Fore.RED}Run is expired.{Style.RESET_ALL}")
                return run
            # These statuses are final too; without this check the loop
            # would poll a failed or cancelled run forever.
            if run.status in ("failed", "cancelled", "incomplete"):
                print(f"{Fore.RED}Run ended with status: {run.status}.{Style.RESET_ALL}")
                return run
            # NOTE(review): "requires_action" is not handled here and nothing
            # in this class submits tool outputs, so such a run keeps being
            # polled until the API moves it to another status.
            print(f"{Fore.YELLOW} OpenAI: Run is not yet completed. Waiting...{run.status} {Style.RESET_ALL}")
            time.sleep(poll_interval)

    def create_assistant(self, name, instructions, tools, files):
        """Create an assistant on the configured chat model and store it.

        Args:
            name: Display name of the assistant.
            instructions: System instructions for the assistant.
            tools: Tool definitions passed through to the API.
            files: File IDs (e.g. from upload_file) attached as ``file_ids``.
        """
        self.assistant = self.client.beta.assistants.create(
            name=name,
            instructions=instructions,
            tools=tools,
            model=self._chat_model,
            file_ids=files
        )

    def create_thread(self):
        """Create a new empty thread and store it as ``self.thread``."""
        self.thread = self.client.beta.threads.create()

    def add_message_to_thread(self, role, content):
        """Append a message to ``self.thread``.

        Requires create_thread() to have been called first.
        """
        self.client.beta.threads.messages.create(
            thread_id=self.thread.id,
            role=role,
            content=content
        )

    def run_assistant(self, instructions):
        """Start a run of ``self.assistant`` on ``self.thread``.

        The run is stored in ``self.run``; poll it with check_run().
        Requires create_assistant() and create_thread() to have been called.
        """
        self.run = self.client.beta.threads.runs.create(
            thread_id=self.thread.id,
            assistant_id=self.assistant.id,
            instructions=instructions
        )

    def process_messages(self):
        """Read ``self.thread``'s messages and estimate their token cost.

        Returns:
            A ``(total_price, ans)`` tuple: the estimated USD cost of all
            text messages read, and the text of the last assistant message
            visited in iteration order ("" if there is none).
        """
        messages = self.client.beta.threads.messages.list(thread_id=self.thread.id)
        # NOTE(review): messages.data is a single page of results, and the
        # API lists newest-first by default, so `ans` ends up as the oldest
        # assistant message on that page — confirm this is the intended one.
        total_price = 0
        ans = ""
        for msg in messages.data:
            # Skip messages with no content or a non-text first part (e.g. an
            # image); there is no text to count or return.
            if not msg.content or getattr(msg.content[0], "text", None) is None:
                continue
            role = msg.role
            content = msg.content[0].text.value
            if role == "user":
                total_price += self._num_tokens_from_string(content) / 1000 * self._INPUT_PRICE_PER_1K
            elif role == "assistant":
                total_price += self._num_tokens_from_string(content) / 1000 * self._OUTPUT_PRICE_PER_1K
                ans = content
        return total_price, ans

    def upload_file(self, file_path):
        """Upload a local file with purpose ``'assistants'``.

        Args:
            file_path: Path of the file to upload; empty means "no file".

        Returns:
            The uploaded file's ID, or None when *file_path* is empty.
        """
        if not file_path:
            return None
        # Close the handle once the upload request has consumed it.
        with open(file_path, "rb") as fh:
            file = self.client.files.create(
                file=fh,
                purpose='assistants'
            )
        print("File successfully uploaded. File ID :", file.id)
        return file.id

    def get_files(self):
        """Return the IDs of the files returned by ``client.files.list()``."""
        return [f.id for f in self.client.files.list()]

    def delete_all_files(self):
        """Delete every file whose ID is returned by get_files()."""
        for file_id in self.get_files():
            # Passed positionally: the SDK parameter is named `file_id`, so
            # the previous `id=` keyword raised TypeError.
            self.client.files.delete(file_id)

    def _num_tokens_from_string(self, string: str) -> int:
        """Returns the number of tokens in a text string."""
        num_tokens = len(self._encoding.encode(string))
        return num_tokens