You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
95 lines
3.5 KiB
95 lines
3.5 KiB
1 year ago
|
import openai
|
||
|
from openai import OpenAI
|
||
|
import tiktoken
|
||
|
from config import Config
|
||
|
from colorama import Fore, Style
|
||
|
import time
|
||
|
class AI_assistant:
    """Thin wrapper around the OpenAI Assistants (beta) API.

    Holds one assistant, one thread and the most recent run. Also provides
    helpers to upload, list and delete files, and to estimate the cost of a
    thread's messages.
    """

    # USD per 1,000 tokens used by process_messages() to estimate cost.
    # NOTE(review): these look like GPT-4 Turbo list prices (matching the
    # tokenizer chosen in __init__); confirm against current pricing for the
    # configured chat model.
    INPUT_PRICE_PER_1K = 0.01
    OUTPUT_PRICE_PER_1K = 0.03

    # Run statuses other than "completed" after which a run never progresses,
    # so check_run() must stop polling instead of looping forever.
    _STOPPED_STATUSES = frozenset({"expired", "failed", "cancelled", "incomplete"})

    def __init__(self, cfg: Config):
        """Configure the OpenAI client from *cfg*.

        Reads ``open_ai_key``, ``open_ai_chat_model``, ``use_stream``,
        ``language`` and ``temperature`` from the config object.
        """
        # Also set the module-level key, presumably for code elsewhere that
        # still relies on the global ``openai`` configuration.
        openai.api_key = cfg.open_ai_key
        self._chat_model = cfg.open_ai_chat_model
        self._use_stream = cfg.use_stream
        # Tokenizer used only for cost estimation in process_messages().
        self._encoding = tiktoken.encoding_for_model('gpt-4-1106-preview')
        self._language = cfg.language
        self._temperature = cfg.temperature
        self.client = OpenAI(api_key=cfg.open_ai_key)
        # Populated by create_assistant(), create_thread() and run_assistant().
        self.assistant = None
        self.thread = None
        self.run = None

    def check_run(self, thread_id, run_id, poll_interval=3, timeout=None):
        """Poll a run until it stops progressing, printing its status.

        Args:
            thread_id: ID of the thread the run belongs to.
            run_id: ID of the run to watch.
            poll_interval: Seconds to sleep between status checks.
            timeout: Maximum seconds to wait, or None to wait indefinitely.

        Returns:
            The last retrieved run object (status "completed" or one of the
            stopped statuses).

        Raises:
            TimeoutError: If *timeout* elapses while the run is still active.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            # Refresh the run object to get the latest status.
            run = self.client.beta.threads.runs.retrieve(
                thread_id=thread_id,
                run_id=run_id,
            )
            if run.status == "completed":
                print(f"{Fore.GREEN} Run is completed.{Style.RESET_ALL}")
                return run
            if run.status in self._STOPPED_STATUSES:
                # e.g. "Run is expired." / "Run is failed."
                print(f"{Fore.RED}Run is {run.status}.{Style.RESET_ALL}")
                return run
            print(f"{Fore.YELLOW} OpenAI: Run is not yet completed. Waiting...{run.status} {Style.RESET_ALL}")
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Run {run_id} still '{run.status}' after {timeout} seconds"
                )
            time.sleep(poll_interval)

    def create_assistant(self, name, instructions, tools, files):
        """Create an assistant on the configured chat model and store it.

        Args:
            name: Display name of the assistant.
            instructions: System instructions for the assistant.
            tools: Tool definitions passed straight to the API.
            files: List of uploaded file IDs to attach.
        """
        # NOTE(review): ``file_ids`` is the Assistants v1 parameter; the v2
        # API replaced it with ``tool_resources`` — confirm the SDK version.
        self.assistant = self.client.beta.assistants.create(
            name=name,
            instructions=instructions,
            tools=tools,
            model=self._chat_model,
            file_ids=files,
        )

    def create_thread(self):
        """Create a new, empty conversation thread and store it."""
        self.thread = self.client.beta.threads.create()

    def add_message_to_thread(self, role, content):
        """Append a message with *role* and *content* to the current thread.

        Requires create_thread() to have been called first.
        """
        self.client.beta.threads.messages.create(
            thread_id=self.thread.id,
            role=role,
            content=content,
        )

    def run_assistant(self, instructions):
        """Start a run of the current assistant on the current thread.

        *instructions* are passed as the run's ``instructions``. The run is
        stored in ``self.run``; use check_run() to wait for it to finish.
        """
        self.run = self.client.beta.threads.runs.create(
            thread_id=self.thread.id,
            assistant_id=self.assistant.id,
            instructions=instructions,
        )

    def process_messages(self):
        """Estimate the thread's cost and return the latest assistant reply.

        User messages are priced at INPUT_PRICE_PER_1K and assistant messages
        at OUTPUT_PRICE_PER_1K, using the tokenizer chosen in __init__.

        Returns:
            A ``(total_price, ans)`` tuple: the estimated cost in USD and the
            text of the newest assistant message ("" if there is none).
        """
        # Newest-first ordering is the API default; it is stated explicitly
        # because selecting the latest answer below depends on it.
        # NOTE(review): only the first page (API default limit) is examined,
        # so the estimate undercounts long threads.
        messages = self.client.beta.threads.messages.list(
            thread_id=self.thread.id,
            order="desc",
        )
        total_price = 0
        ans = ""
        found_answer = False
        for msg in messages.data:
            content = self._message_text(msg)
            if msg.role == "user":
                total_price += self._num_tokens_from_string(content) / 1000 * self.INPUT_PRICE_PER_1K
            elif msg.role == "assistant":
                total_price += self._num_tokens_from_string(content) / 1000 * self.OUTPUT_PRICE_PER_1K
                # First assistant message seen is the newest one.
                if not found_answer:
                    ans = content
                    found_answer = True
        return total_price, ans

    @staticmethod
    def _message_text(msg):
        """Join the text parts of a thread message, skipping non-text parts."""
        parts = []
        for block in msg.content:
            # Image parts have no ``text`` attribute.
            text = getattr(block, "text", None)
            if text is not None:
                parts.append(text.value)
        return "\n".join(parts)

    def upload_file(self, file_path):
        """Upload a local file to the OpenAI Files API for assistant use.

        Args:
            file_path: Path to the file to upload. An empty path skips the
                upload.

        Returns:
            The uploaded file's ID, or None if *file_path* is empty.
        """
        if not file_path:
            return None
        # Close the handle once the upload request has been sent.
        with open(file_path, "rb") as fh:
            file = self.client.files.create(
                file=fh,
                purpose='assistants',
            )
        print("File successfully uploaded. File ID :", file.id)
        return file.id

    def get_files(self):
        """Return the IDs of every file visible to this API key."""
        return [item.id for item in self.client.files.list()]

    def delete_all_files(self):
        """Delete every file returned by get_files().

        Warning: this is not restricted to files uploaded by this class.
        """
        for file_id in self.get_files():
            # Passed positionally: the v1 SDK names this parameter ``file_id``,
            # so the previous ``id=`` keyword was rejected.
            self.client.files.delete(file_id)

    def _num_tokens_from_string(self, string: str) -> int:
        """Return the number of tokens in *string* under self._encoding."""
        return len(self._encoding.encode(string))