1,301 changes: 1,301 additions & 0 deletions research/BGE_Coder/data_generation/constant.py

Large diffs are not rendered by default.

104 changes: 104 additions & 0 deletions research/BGE_Coder/data_generation/corpus_generator.py
@@ -0,0 +1,104 @@
import os
import random
import datasets
from tqdm import tqdm
from typing import List, Tuple

from utils import clean_code
from constant import DocLength


class CorpusGenerator:
def __init__(
self,
cache_dir: str = None,
):
self.cache_dir = cache_dir

def _load_corpus(self, corpus_dir: str, doc_length: List[str], external_path: List[str],
source_language: str, stop_threshold: int = -1):
"""
        Load available documents for a given task from the CoIR-Retrieval dataset.
"""

corpus_list = []

if corpus_dir is not None and os.path.exists(corpus_dir):
file_list = os.listdir(corpus_dir)
random.shuffle(file_list)

            for file in file_list:
                # Skip non-JSONL files and files whose name does not match any
                # requested document-length bucket.
                if not file.endswith('.jsonl'):
                    continue
                if not any(DocLength[d_length].value in file for d_length in doc_length):
                    continue
file_path = os.path.join(corpus_dir, file)
corpus = datasets.load_dataset('json', data_files=file_path, cache_dir=self.cache_dir)['train']
for data in tqdm(corpus, desc="Loading corpus"):
if source_language is None:
lang = os.path.basename(corpus_dir)
data['language'] = lang
else:
data['language'] = source_language

text = clean_code(data["text"], data["language"], length_threshold=200)
data["text"] = text
if text != '':
corpus_list.append(data)

if stop_threshold > 0 and len(corpus_list) > stop_threshold:
break
                break  # only the first matching (randomly shuffled) file is loaded per call

        # external_path may be None (see run()'s default), so guard the iteration.
        for ep in external_path or []:
if os.path.exists(ep):
corpus = datasets.load_dataset('json', data_files=ep, cache_dir=self.cache_dir)['train']
for data in tqdm(corpus, desc="Loading corpus"):
if source_language is None:
lang = os.path.basename(os.path.dirname(ep))
data['language'] = lang
else:
data['language'] = source_language

                    # Fall back to the first positive passage when the text field is missing.
                    if "text" not in data:
                        data["text"] = data["pos"][0]

                    text = clean_code(data["text"], data["language"], length_threshold=200)
                    data["text"] = text
                    if text != '':
                        corpus_list.append(data)

return corpus_list

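    # Returns two random views of the loaded corpus: a small sample of at most
    # num_samples documents, and the full pool capped at max_corpus documents.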
def run(
self,
num_samples: int = -1,
max_corpus: int = -1,
corpus_dir: str = None,
doc_length: List[str] = ["len_0_500"],
external_path: List[str] = None,
source_language: str = None
) -> Tuple[List[dict], List[dict]]:
stop_threshold = max(num_samples * 10, max_corpus * 2)
corpus_list = self._load_corpus(
corpus_dir, doc_length, external_path, source_language, stop_threshold
)

if num_samples > 0 and num_samples < len(corpus_list):
small_corpus_list = random.sample(corpus_list, num_samples)
else:
small_corpus_list = corpus_list

        if max_corpus > 0 and max_corpus < len(corpus_list):
            corpus_list = random.sample(corpus_list, max_corpus)

return small_corpus_list, corpus_list
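
A minimal usage sketch for CorpusGenerator (the cache path, corpus directory,
and language below are illustrative placeholders, not values from this PR):

# sketch: sample 100 seed documents plus a pool of up to 1,000 documents
# from a directory of length-bucketed JSONL corpus files.
from corpus_generator import CorpusGenerator

generator = CorpusGenerator(cache_dir="./cache")  # placeholder cache dir
samples, pool = generator.run(
    num_samples=100,
    max_corpus=1000,
    corpus_dir="./corpus/python",  # placeholder corpus directory
    doc_length=["len_0_500"],
    source_language="python",
)
print(len(samples), len(pool))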
127 changes: 127 additions & 0 deletions research/BGE_Coder/data_generation/format_generated_examples.py
@@ -0,0 +1,127 @@
import os
import json
from constant import Language, CodeLanguage, TaskType, CODE_TRANSLATION_RETRIEVAL_PAIRS, \
get_pos_as_input_by_task_type


def format_generated_examples(
file_path: str,
save_path: str,
task_type: TaskType
):
if os.path.exists(save_path):
return

if not os.path.exists(file_path):
print("====================================")
print("Warning: file not found! Maybe need to generate it first.")
print(f"file_path: {file_path}")
return

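    # For some task types the positive passage, rather than the query, is the
    # model input; this flag decides which side of the triplet becomes "input"
    # and which becomes "output".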
pos_as_input = get_pos_as_input_by_task_type(task_type)

data_list = []
with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
data = json.loads(line)

if pos_as_input:
_input = data["pos"][0]
_output = data["query"]
else:
_input = data["query"]
_output = data["pos"][0]

            # Heuristic filter: skip examples whose input still contains the
            # word 'provided', likely leftover instruction text from generation.
            if 'provided' in _input:
                continue
if len(_input) > 12000 or len(_output) > 12000:
continue

data_list.append({
"input": _input,
"output": _output
})

if len(data_list) == 0:
print("====================================")
print("Warning: no data found!")
print(f"file_path: {file_path}")
return

os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "w", encoding="utf-8") as f:
json.dump(data_list, f, indent=4, ensure_ascii=False)


def main():
original_gen_examples_dir = "./examples"

formatted_examples_dir = "./filtered_for_generation"

for language in Language:
for task_type in TaskType:
if task_type == TaskType.code_translation_retrieval:
for code_language_pair in CODE_TRANSLATION_RETRIEVAL_PAIRS:
code_language, tgt_code_language = code_language_pair

file_path = os.path.join(
original_gen_examples_dir,
language.name, task_type.name, f"{language.name}-{code_language.name}-to-{tgt_code_language.name}-triplets.jsonl"
)
save_path = os.path.join(
formatted_examples_dir,
language.name, task_type.name, f"{code_language.name}-to-{tgt_code_language.name}_sample_examples.json"
)

format_generated_examples(file_path, save_path, task_type)

for code_language_pair in CODE_TRANSLATION_RETRIEVAL_PAIRS:
tgt_code_language, code_language = code_language_pair

file_path = os.path.join(
original_gen_examples_dir,
language.name, task_type.name, f"{language.name}-{code_language.name}-to-{tgt_code_language.name}-triplets.jsonl"
)
save_path = os.path.join(
formatted_examples_dir,
language.name, task_type.name, f"{code_language.name}-to-{tgt_code_language.name}_sample_examples.json"
)

format_generated_examples(file_path, save_path, task_type)

elif task_type == TaskType.text2sql_retrieval:
file_path = os.path.join(
original_gen_examples_dir,
language.name, task_type.name, f"{language.name}-sql-triplets.jsonl"
)
save_path = os.path.join(
formatted_examples_dir,
language.name, task_type.name, "sql_sample_examples.json"
)

format_generated_examples(file_path, save_path, task_type)

elif task_type == TaskType.code_context_retrieval:
continue

else:
for code_language in CodeLanguage:
if code_language == CodeLanguage.null:
continue

file_path = os.path.join(
original_gen_examples_dir,
language.name, task_type.name, f"{language.name}-{code_language.name}-triplets.jsonl"
)
save_path = os.path.join(
formatted_examples_dir,
language.name, task_type.name, f"{code_language.name}_sample_examples.json"
)

format_generated_examples(file_path, save_path, task_type)

print("All done!")


if __name__ == "__main__":
main()
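
For reference, each line of the triplet files read above is expected to carry
at least a "query" string and a "pos" list (field names taken from the reader);
the example content itself is illustrative:

{"query": "How do I reverse a list in Python?", "pos": ["def reverse(xs):\n    return xs[::-1]"]}

which this script flattens (with pos_as_input=False) into entries of the form:

{"input": "How do I reverse a list in Python?", "output": "def reverse(xs):\n    return xs[::-1]"}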
134 changes: 134 additions & 0 deletions research/BGE_Coder/data_generation/llm.py
@@ -0,0 +1,134 @@
import os
import time
import openai
import random
import tiktoken
import threading
from openai import OpenAI, AzureOpenAI
from typing import Tuple


class LLM:
def __init__(
self,
        model: str = "Qwen2-5-Coder-32B-Instruct",
model_type: str = "open-source",
port: int = 8000,
):
if model_type == "open-source":
self.client = OpenAI(
api_key="EMPTY",
base_url=f"http://localhost:{port}/v1/"
)
elif model_type == "azure":
self.client = AzureOpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
api_version=os.getenv("AZURE_API_VERSION", "2024-02-01"),
azure_endpoint=os.getenv("AZURE_ENDPOINT"),
azure_deployment=os.getenv("OPENAI_DEPLOYMENT_NAME", 'gpt-35-turbo')
)
elif model_type == "openai":
self.client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
base_url=os.getenv("OPENAI_BASE_URL", None)
)
else:
raise ValueError("model_type must be one of ['open-source', 'azure', 'openai']")

self.model = model
self.tokenizer = tiktoken.get_encoding("o200k_base")

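    # Split a text at a random point between 40% and 70% of its tokens (by
    # default), e.g. to derive a (context, continuation) pair from one document.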
def split_text(self, text: str, anchor_points: Tuple[float, float] = (0.4, 0.7)):
token_ids = self.tokenizer.encode(text)
anchor_point = random.uniform(anchor_points[0], anchor_points[1])
split_index = int(len(token_ids) * anchor_point)
return self.tokenizer.decode(token_ids[:split_index]), self.tokenizer.decode(token_ids[split_index:])

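    # chat() runs the API call in a worker thread and watches it from here:
    # transient errors are retried every 5 seconds until endure_time_limit is
    # reached, and a hung request is abandoned once the same limit elapses.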
def chat(
self,
prompt: str,
max_tokens: int = 8192,
        logit_bias: dict = None,
n: int = 1,
temperature: float = 1.0,
top_p: float = 0.6,
repetition_penalty: float = 1.0,
remove_thinking: bool = True,
timeout: int = 90,
):
endure_time = 0
endure_time_limit = timeout * 2

def create_completion(results):
try:
completion = self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
max_tokens=max_tokens,
                    logit_bias=logit_bias if logit_bias is not None else {},
n=n,
temperature=temperature,
top_p=top_p,
                    # repetition_penalty is passed via extra_body for
                    # OpenAI-compatible servers (e.g. vLLM) that support it.
                    extra_body={'repetition_penalty': repetition_penalty},
timeout=timeout,
)
results["content"] = [x.message.content for x in completion.choices[:n]]
except openai.BadRequestError as e:
# The response was filtered due to the prompt triggering Azure OpenAI's content management policy.
results["content"] = [None for _ in range(n)]
except openai.APIConnectionError as e:
results["error"] = f'APIConnectionError({e})'
except openai.RateLimitError as e:
results["error"] = f'RateLimitError({e})'
except Exception as e:
results["error"] = f"Error: {e}"

while True:
results = {"content": None, "error": None}
completion_thread = threading.Thread(target=create_completion, args=(results,))
completion_thread.start()

start_time = time.time()
while completion_thread.is_alive():
elapsed_time = time.time() - start_time
                if elapsed_time > endure_time_limit:
                    # Give up on this prompt; the worker thread is left to finish
                    # (or fail) in the background.
                    print("Completion timeout exceeded. Aborting...")
                    return [None for _ in range(n)]
time.sleep(1)

# If an error occurred during result processing
if results["error"]:
if endure_time >= endure_time_limit:
print(f'{results["error"]} - Skip this prompt.')
return [None for _ in range(n)]
print(f"{results['error']} - Waiting for 5 seconds...")
endure_time += 5
time.sleep(5)
continue

content_list = results["content"]
if remove_thinking:
content_list = [x.split('</think>')[-1].strip('\n').strip() if x is not None else None for x in content_list]
return content_list


if __name__ == "__main__":
llm = LLM(
model="gpt-4o-mini-2024-07-18",
model_type="openai"
)

prompt = "hello, who are you?"
response = llm.chat(prompt)[0]
print(response)
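
For the "azure" and "openai" model types, credentials are read from the
environment. A minimal setup sketch with placeholder values (the variable
names come from the constructor above):

import os

os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # placeholder key
os.environ.setdefault("AZURE_ENDPOINT", "https://my-resource.openai.azure.com")  # placeholder
os.environ.setdefault("AZURE_API_VERSION", "2024-02-01")
os.environ.setdefault("OPENAI_DEPLOYMENT_NAME", "gpt-35-turbo")

from llm import LLM

llm = LLM(model="gpt-35-turbo", model_type="azure")
print(llm.chat("Say hi in one word.")[0])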

