# (Web-page residue captured with this file: "You cannot select more than 25
# topics. Topics must start with a letter or number, can include dashes ('-')
# and can be up to 35 characters long." File size: 95 lines, 3.5 KiB.)
import time

import openai
import tiktoken
from colorama import Fore, Style
from openai import OpenAI

from config import Config
|
class AI_assistant:
    """Thin wrapper around the OpenAI Assistants (beta) API.

    Keeps one assistant, one thread and the most recent run created through
    this object. Also provides helpers to upload, list and delete files, and
    to estimate the cost of a thread's messages.
    """

    # USD per 1,000 tokens used by process_messages() for the cost estimate.
    # NOTE(review): these look like gpt-4-1106-preview input/output prices;
    # confirm they match the configured chat model.
    _USER_PRICE_PER_1K = 0.01
    _ASSISTANT_PRICE_PER_1K = 0.03

    # Run statuses after which the run can never reach "completed", so
    # polling must stop instead of looping forever.
    _FAILED_STATUSES = ("failed", "cancelled", "incomplete")

    def __init__(self, cfg: Config):
        """Create the API client and read settings from *cfg*.

        Args:
            cfg: project configuration. Fields read here are
                open_ai_key, open_ai_chat_model, use_stream, language
                and temperature.
        """
        # Kept for code that still uses the module-level openai API; the
        # client below receives the key explicitly.
        openai.api_key = cfg.open_ai_key
        # openai.proxy = cfg.open_ai_proxy
        self._chat_model = cfg.open_ai_chat_model
        self._use_stream = cfg.use_stream
        # Token counting always uses this model's encoding, whatever
        # _chat_model is; counts are an estimate for pricing only.
        self._encoding = tiktoken.encoding_for_model('gpt-4-1106-preview')
        self._language = cfg.language
        self._temperature = cfg.temperature
        self.client = OpenAI(api_key=cfg.open_ai_key)
        # Populated by create_assistant / create_thread / run_assistant.
        self.assistant = None
        self.thread = None
        self.run = None

    def check_run(self, thread_id, run_id, poll_interval=3):
        """Poll a run until it finishes or can no longer progress on its own.

        Args:
            thread_id: ID of the thread the run belongs to.
            run_id: ID of the run to poll.
            poll_interval: seconds to sleep between polls (default 3).

        Returns:
            The last retrieved run object. Callers can inspect ``run.status``,
            e.g. to submit tool outputs on ``requires_action`` or to report
            a ``failed`` run.
        """
        while True:
            # Refresh the run object to get the latest status.
            run = self.client.beta.threads.runs.retrieve(
                thread_id=thread_id,
                run_id=run_id,
            )
            status = run.status

            if status == "completed":
                print(f"{Fore.GREEN} Run is completed.{Style.RESET_ALL}")
                return run
            if status == "expired":
                print(f"{Fore.RED}Run is expired.{Style.RESET_ALL}")
                return run
            if status in self._FAILED_STATUSES:
                print(f"{Fore.RED}Run ended with status: {status}.{Style.RESET_ALL}")
                return run
            if status == "requires_action":
                # The run waits for the caller (e.g. tool outputs); polling
                # further would only wait for it to expire.
                print(f"{Fore.YELLOW} OpenAI: Run requires action.{Style.RESET_ALL}")
                return run

            print(f"{Fore.YELLOW} OpenAI: Run is not yet completed. Waiting...{run.status} {Style.RESET_ALL}")
            time.sleep(poll_interval)

    def create_assistant(self, name, instructions, tools, files):
        """Create an assistant on the configured chat model and store it.

        Args:
            name: display name of the assistant.
            instructions: system instructions for the assistant.
            tools: tool definitions passed through to the API.
            files: list of uploaded file IDs to attach.
        """
        # NOTE(review): ``file_ids`` is the Assistants v1 parameter; v2 moved
        # file attachments to ``tool_resources``. Confirm the SDK version.
        self.assistant = self.client.beta.assistants.create(
            name=name,
            instructions=instructions,
            tools=tools,
            model=self._chat_model,
            file_ids=files,
        )

    def create_thread(self):
        """Create a new, empty conversation thread and store it."""
        self.thread = self.client.beta.threads.create()

    def add_message_to_thread(self, role, content):
        """Append a message to the current thread (requires create_thread)."""
        self.client.beta.threads.messages.create(
            thread_id=self.thread.id,
            role=role,
            content=content,
        )

    def run_assistant(self, instructions):
        """Start a run of the stored assistant on the current thread.

        Args:
            instructions: run-level instructions sent with the run.
        """
        self.run = self.client.beta.threads.runs.create(
            thread_id=self.thread.id,
            assistant_id=self.assistant.id,
            instructions=instructions,
        )

    @staticmethod
    def _first_text(msg):
        """Return the value of the first text content part of *msg*, or None."""
        for part in msg.content:
            if getattr(part, "type", None) == "text":
                return part.text.value
        return None

    def process_messages(self):
        """Estimate the thread's cost and return its most recent answer.

        Only the messages returned by one ``messages.list`` call (one page)
        are considered. Messages without a text part are skipped.

        Returns:
            tuple: ``(total_price, ans)``. ``total_price`` is the estimated
            USD cost of the listed user and assistant messages. ``ans`` is
            the text of the newest assistant message, or "" if there is none.
        """
        # Ask for newest-first explicitly so the first assistant message seen
        # is the latest answer.
        messages = self.client.beta.threads.messages.list(
            thread_id=self.thread.id,
            order="desc",
        )
        total_price = 0
        ans = None
        for msg in messages.data:
            content = self._first_text(msg)
            if content is None:
                continue
            tokens = self._num_tokens_from_string(content)
            if msg.role == "user":
                total_price += tokens / 1000 * self._USER_PRICE_PER_1K
            elif msg.role == "assistant":
                total_price += tokens / 1000 * self._ASSISTANT_PRICE_PER_1K
                if ans is None:
                    ans = content
        return total_price, ans or ""

    def upload_file(self, file_path):
        """Upload a local file for use by assistants.

        Args:
            file_path: path of the file to upload. An empty value (or None)
                means "no file".

        Returns:
            The uploaded file's ID, or None when *file_path* is empty.
        """
        if not file_path:
            return None
        # Close the handle once the upload request has been made.
        with open(file_path, "rb") as handle:
            file = self.client.files.create(
                file=handle,
                purpose='assistants',
            )
        print("File successfully uploaded. File ID :", file.id)
        return file.id

    def get_files(self):
        """Return the IDs of all files visible to this API key."""
        return [file.id for file in self.client.files.list()]

    def delete_all_files(self):
        """Delete every file returned by get_files()."""
        for file_id in self.get_files():
            # openai-python v1 takes the file ID as ``file_id`` (positional).
            self.client.files.delete(file_id)

    def _num_tokens_from_string(self, string: str) -> int:
        """Returns the number of tokens in a text string."""
        return len(self._encoding.encode(string))