# embed-bge-m3/FlagEmbedding/research/BGE_Coder/data_generation/llm.py

import os
import time
import random
import threading
from typing import Tuple

import openai
import tiktoken
from openai import OpenAI, AzureOpenAI


class LLM:
    """Thin wrapper around an OpenAI-compatible chat endpoint.

    Supports a locally served open-source model (on `port`), Azure OpenAI,
    or the public OpenAI API, selected via `model_type`.
    """

    def __init__(
        self,
        model: str = "Qwen2-5-Coder-32B-Instruct",
        model_type: str = "open-source",
        port: int = 8000,
    ):
        if model_type == "open-source":
            # Local OpenAI-compatible server; the API key is unused.
            self.client = OpenAI(
                api_key="EMPTY",
                base_url=f"http://localhost:{port}/v1/"
            )
        elif model_type == "azure":
            self.client = AzureOpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                api_version=os.getenv("AZURE_API_VERSION", "2024-02-01"),
                azure_endpoint=os.getenv("AZURE_ENDPOINT"),
                azure_deployment=os.getenv("OPENAI_DEPLOYMENT_NAME", "gpt-35-turbo")
            )
        elif model_type == "openai":
            self.client = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_BASE_URL", None)
            )
        else:
            raise ValueError("model_type must be one of ['open-source', 'azure', 'openai']")
        self.model = model
        # The tokenizer is only used by split_text, not by the API calls.
        self.tokenizer = tiktoken.get_encoding("o200k_base")
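
    # Hypothetical serving command (an assumption, not part of this repo): the
    # "open-source" branch expects an OpenAI-compatible server such as vLLM,
    # e.g.  vllm serve Qwen/Qwen2.5-Coder-32B-Instruct --port 8000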

    def split_text(self, text: str, anchor_points: Tuple[float, float] = (0.4, 0.7)):
        """Split `text` into (prefix, suffix) at a random token position drawn
        uniformly from the fraction range given by `anchor_points`."""
        token_ids = self.tokenizer.encode(text)
        anchor_point = random.uniform(anchor_points[0], anchor_points[1])
        split_index = int(len(token_ids) * anchor_point)
        return self.tokenizer.decode(token_ids[:split_index]), self.tokenizer.decode(token_ids[split_index:])

    def chat(
        self,
        prompt: str,
        max_tokens: int = 8192,
        logit_bias: dict = None,
        n: int = 1,
        temperature: float = 1.0,
        top_p: float = 0.6,
        repetition_penalty: float = 1.0,
        remove_thinking: bool = True,
        timeout: int = 90,
    ):
        """Send a single-turn chat request and return a list of `n` completions.

        Retries on transient API errors for up to `2 * timeout` seconds and
        returns a list of `None` on unrecoverable failure.
        """
        endure_time = 0
        endure_time_limit = timeout * 2

        def create_completion(results):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    logit_bias=logit_bias if logit_bias is not None else {},
                    n=n,
                    temperature=temperature,
                    top_p=top_p,
                    # `repetition_penalty` is a server-side extension (e.g. vLLM);
                    # OpenAI and Azure ignore unknown extra_body fields.
                    extra_body={'repetition_penalty': repetition_penalty},
                    timeout=timeout,
                )
                results["content"] = [x.message.content for x in completion.choices[:n]]
            except openai.BadRequestError:
                # E.g. the prompt triggered Azure OpenAI's content management policy.
                results["content"] = [None for _ in range(n)]
            except openai.APIConnectionError as e:
                results["error"] = f'APIConnectionError({e})'
            except openai.RateLimitError as e:
                results["error"] = f'RateLimitError({e})'
            except Exception as e:
                results["error"] = f"Error: {e}"

        while True:
            results = {"content": None, "error": None}
            # Daemon thread so an abandoned request cannot block interpreter exit.
            completion_thread = threading.Thread(target=create_completion, args=(results,), daemon=True)
            completion_thread.start()
            start_time = time.time()
            while completion_thread.is_alive():
                elapsed_time = time.time() - start_time
                if elapsed_time > endure_time_limit:
                    print("Completion timeout exceeded. Aborting...")
                    return [None for _ in range(n)]
                time.sleep(1)
            # Retry on transient errors until the endurance budget is spent.
            if results["error"]:
                if endure_time >= endure_time_limit:
                    print(f'{results["error"]} - Skip this prompt.')
                    return [None for _ in range(n)]
                print(f"{results['error']} - Waiting for 5 seconds...")
                endure_time += 5
                time.sleep(5)
                continue
            content_list = results["content"]
            if remove_thinking:
                # Drop any `<think>...</think>` reasoning prefix emitted by the model.
                content_list = [x.split('</think>')[-1].strip('\n').strip() if x is not None else None for x in content_list]
            return content_list


if __name__ == "__main__":
    llm = LLM(
        model="gpt-4o-mini-2024-07-18",
        model_type="openai"
    )
    prompt = "hello, who are you?"
    response = llm.chat(prompt)[0]
    print(response)
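
    # A minimal sketch of split_text on a hypothetical input: it returns a
    # (prefix, suffix) pair cut at a random token boundary drawn from
    # 40-70% of the text's token length.
    prefix, suffix = llm.split_text("hello, who are you? " * 20)
    print(prefix, "|", suffix)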