原始代码
This commit is contained in:
736
optimization/optimize_instructions.py
Normal file
736
optimization/optimize_instructions.py
Normal file
@@ -0,0 +1,736 @@
|
||||
|
||||
# Standard-library imports.
import datetime
import functools
import os
import sys

# Repository root: three directory levels above this file. It is prepended to
# sys.path so the `opro` package resolves when this script is run directly.
OPRO_ROOT_PATH = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
)
sys.path.insert(0, OPRO_ROOT_PATH)

# Third-party and project imports; placed after the sys.path tweak above so
# the `opro` package imports succeed.
from absl import app
from absl import flags
import google.generativeai as palm
import numpy as np
import openai
from opro import prompt_utils
from opro.optimization import opt_utils
import pandas as pd

FLAGS = flags.FLAGS
# Folder that holds the benchmark data (MMLU / BBH / GSM8K subfolders).
ROOT_DATA_FOLDER_PATH = os.path.join(OPRO_ROOT_PATH, "data")

# Path to a local vLLM model checkpoint; used when the scorer or optimizer
# model name is "local".
flags.DEFINE_string("local_model_path", "", "Path to local vLLM model.")

# API keys for the hosted model backends. Only the key(s) matching the chosen
# scorer/optimizer backends need to be set.
_OPENAI_API_KEY = flags.DEFINE_string(
    "openai_api_key", "", "The OpenAI API key."
)

_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.")

# Model used to score candidate instructions on task examples.
_SCORER = flags.DEFINE_string(
    "scorer", "text-bison", "The name of the scorer LLM."
)

# Model used to propose new candidate instructions each search step.
_OPTIMIZER = flags.DEFINE_string(
    "optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM."
)

_DATASET = flags.DEFINE_string(
    "dataset", "gsm8k", "The name of dataset to search for instructions on."
)

_TASK = flags.DEFINE_string(
    "task",
    "train",
    "The name of task within the above dataset to search for instructions on.",
)

# Where the searched instruction is inserted into the prompt; validated in
# main() against {before_Q, Q_begin, Q_end, A_begin}.
_INSTRUCTION_POS = flags.DEFINE_string(
    "instruction_pos",
    "A_begin",
    "The position of the instruction to search for.",
)

_META_PROMPT_TYPE = flags.DEFINE_string(
    "meta_prompt_type",
    "both_instructions_and_exemplars",
    "The type of meta-prompt: whether to have both previous instructions and"
    " dataset exemplars (often for fine-tuned optimizers), or to have only"
    " previous instructions (often for pre-trained optimizers).",
)
||||
def main(_):
  """Runs OPRO instruction optimization end to end.

  Reads the command-line flags, validates the dataset/task/model choices,
  configures the scorer and optimizer LLM backends (PaLM, OpenAI, or a local
  vLLM model), loads the task data, splits it into train/eval/test, and then
  launches the evolutionary search via `opt_utils.run_evolution`.

  Args:
    _: unused positional argument supplied by `absl.app.run`.

  Raises:
    ValueError: if the scorer or optimizer model name is not supported.
    SystemExit: exits with status 1 if the optimization loop raises.
  """
  local_model_path = FLAGS.local_model_path
  openai_api_key = _OPENAI_API_KEY.value
  palm_api_key = _PALM_API_KEY.value
  scorer_llm_name = _SCORER.value
  optimizer_llm_name = _OPTIMIZER.value
  dataset_name = _DATASET.value.lower()
  task_name = _TASK.value
  meta_prompt_type = _META_PROMPT_TYPE.value

  assert dataset_name in {
      "mmlu",
      "bbh",
      "gsm8k",
  }, "The lower-case dataset name must be one of mmlu, bbh, or gsm8k."
  if dataset_name == "mmlu":
    assert task_name in {
        "STEM",
        "humanities",
        "social sciences",
        "other (business, health, misc.)",
    }  # for now only support searching on one MMLU category
  elif dataset_name == "bbh":
    assert task_name in {
        "boolean_expressions",
        "causal_judgement",
        "date_understanding",
        "disambiguation_qa",
        "dyck_languages",
        "formal_fallacies",
        "geometric_shapes",
        "hyperbaton",
        "logical_deduction_five_objects",
        "logical_deduction_seven_objects",
        "logical_deduction_three_objects",
        "movie_recommendation",
        "multistep_arithmetic_two",
        "navigate",
        "object_counting",
        "penguins_in_a_table",
        "reasoning_about_colored_objects",
        "ruin_names",
        "salient_translation_error_detection",
        "snarks",
        "sports_understanding",
        "temporal_sequences",
        "tracking_shuffled_objects_five_objects",
        "tracking_shuffled_objects_seven_objects",
        "tracking_shuffled_objects_three_objects",
        "web_of_lies",
        "word_sorting",
    }
  else:
    assert dataset_name == "gsm8k"
    assert task_name in {"train", "test"}

  assert scorer_llm_name in {
      "text-bison",
      "gpt-3.5-turbo",
      "gpt-4",
      "local",
  }
  assert optimizer_llm_name in {
      "text-bison",
      "gpt-3.5-turbo",
      "gpt-4",
      "local",
  }
  assert meta_prompt_type in {
      "both_instructions_and_exemplars",
      "instructions_only",
  }

  instruction_pos = _INSTRUCTION_POS.value
  assert instruction_pos in {
      "before_Q",
      "Q_begin",
      "Q_end",
      "A_begin",
  }, (
      "The instruction position should be either before the question, or at the"
      " beginning of the question, at the end of the question, or at the"
      " beginning of the answer."
  )
  print(
      f"scorer: {scorer_llm_name}, optimizer: {optimizer_llm_name}, dataset:"
      f" {dataset_name}, task: {task_name}, instruction_pos: {instruction_pos}"
  )

  # ------------- configure API credentials for the scorer backend -----------
  if scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
    assert openai_api_key, "The OpenAI API key must be provided."
    openai.api_key = openai_api_key
  elif scorer_llm_name == "text-bison":
    assert (
        palm_api_key
    ), "A PaLM API key is needed when prompting the text-bison model."
    palm.configure(api_key=palm_api_key)
  elif scorer_llm_name == "local":
    # Local vLLM model; no API key required.
    pass
  else:
    raise ValueError(f"Unknown scorer model: {scorer_llm_name}")

  # ------------ configure API credentials for the optimizer backend ---------
  if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
    assert openai_api_key, "The OpenAI API key must be provided."
    openai.api_key = openai_api_key
  elif optimizer_llm_name == "text-bison":
    assert (
        palm_api_key
    ), "A PaLM API key is needed when prompting the text-bison model."
    palm.configure(api_key=palm_api_key)
  elif optimizer_llm_name == "local":
    # Local vLLM model; no API key required.
    pass
  else:
    # Fixed: this message previously said "Unknown scorer model".
    raise ValueError(f"Unknown optimizer model: {optimizer_llm_name}")

  # Resolve the on-disk data folder for the chosen dataset.
  if dataset_name == "mmlu":
    root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "MMLU-data")
  elif dataset_name == "bbh":
    root_data_folder_path = os.path.join(
        ROOT_DATA_FOLDER_PATH, "BIG-Bench-Hard-data/"
    )
  else:
    assert dataset_name == "gsm8k"
    root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "gsm_data")

  # =================== create the result directory ==========================
  # Timestamp like "2024-01-31-12-34-56", safe for use in a folder name.
  datetime_str = (
      str(datetime.datetime.now().replace(microsecond=0))
      .replace(" ", "-")
      .replace(":", "-")
  )

  save_folder = os.path.join(
      OPRO_ROOT_PATH,
      "outputs",
      "optimization-results",
      f"{dataset_name.upper()}-{task_name}-s-{scorer_llm_name}-o-{optimizer_llm_name}-{datetime_str}/",
  )

  result_by_instruction_folder = os.path.join(
      save_folder, "result_by_instruction"
  )
  print(f"Results will be saved to: {os.path.abspath(result_by_instruction_folder)}")
  os.makedirs(result_by_instruction_folder, exist_ok=True)
  print(f"result directory:\n{save_folder}")

  # ====================== scorer model configs ==============================
  if scorer_llm_name == "text-bison":
    # when prompting text-bison with Cloud API
    scorer_finetuned_palm_temperature = 0.0
    scorer_finetuned_palm_max_decode_steps = 1024
    scorer_finetuned_palm_batch_size = 1
    scorer_finetuned_palm_num_servers = 1
    scorer_finetuned_palm_dict = dict()
    scorer_finetuned_palm_dict["temperature"] = (
        scorer_finetuned_palm_temperature
    )
    scorer_finetuned_palm_dict["num_servers"] = (
        scorer_finetuned_palm_num_servers
    )
    scorer_finetuned_palm_dict["batch_size"] = scorer_finetuned_palm_batch_size
    scorer_finetuned_palm_dict["max_decode_steps"] = (
        scorer_finetuned_palm_max_decode_steps
    )

    call_scorer_finetuned_palm_server_func = functools.partial(
        prompt_utils.call_palm_server_from_cloud,
        model="text-bison-001",
        temperature=scorer_finetuned_palm_dict["temperature"],
        max_decode_steps=scorer_finetuned_palm_dict["max_decode_steps"],
    )

    scorer_llm_dict = {
        "model_type": scorer_llm_name.lower(),
    }
    scorer_llm_dict.update(scorer_finetuned_palm_dict)
    call_scorer_server_func = call_scorer_finetuned_palm_server_func

  elif scorer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}:
    # Routed through the locally implemented vLLM-compatible call function.
    scorer_gpt_max_decode_steps = 1024
    # scorer_gpt_max_decode_steps = 512
    scorer_gpt_temperature = 0.0

    scorer_llm_dict = {
        "model_type": scorer_llm_name.lower(),
        "max_decode_steps": scorer_gpt_max_decode_steps,
        "temperature": scorer_gpt_temperature,
        "num_decodes": 1,
        "batch_size": 1,
        "num_servers": 1,
    }

    call_scorer_server_func = functools.partial(
        prompt_utils.call_openai_server_func,  # local vLLM-backed call function
        max_decode_steps=scorer_gpt_max_decode_steps,
        temperature=scorer_gpt_temperature,
        local_model_path=local_model_path,  # path to the local model
    )

  else:
    raise ValueError(f"Unsupported scorer_llm_name: {scorer_llm_name}")

  # ====================== optimizer model configs ============================
  if optimizer_llm_name.lower() == "text-bison":
    # when prompting text-bison with Cloud API
    optimizer_finetuned_palm_temperature = 1.0
    optimizer_finetuned_palm_num_decodes = 8
    optimizer_finetuned_palm_max_decode_steps = 1024
    optimizer_finetuned_palm_batch_size = 1
    optimizer_finetuned_palm_num_servers = 1
    optimizer_finetuned_palm_dict = dict()
    optimizer_finetuned_palm_dict["temperature"] = (
        optimizer_finetuned_palm_temperature
    )
    optimizer_finetuned_palm_dict["num_decodes"] = (
        optimizer_finetuned_palm_num_decodes
    )
    optimizer_finetuned_palm_dict["batch_size"] = (
        optimizer_finetuned_palm_batch_size
    )
    optimizer_finetuned_palm_dict["num_servers"] = (
        optimizer_finetuned_palm_num_servers
    )
    optimizer_finetuned_palm_dict["max_decode_steps"] = (
        optimizer_finetuned_palm_max_decode_steps
    )

    call_optimizer_finetuned_palm_server_func = functools.partial(
        prompt_utils.call_palm_server_from_cloud,
        model="text-bison-001",
        temperature=optimizer_finetuned_palm_dict["temperature"],
        max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"],
    )

    optimizer_llm_dict = {
        "model_type": optimizer_llm_name.lower(),
    }
    optimizer_llm_dict.update(optimizer_finetuned_palm_dict)
    call_optimizer_server_func = call_optimizer_finetuned_palm_server_func

  elif optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}:
    # Routed through the local vLLM-backed call function.
    optimizer_gpt_max_decode_steps = 512
    optimizer_gpt_temperature = 1.0

    optimizer_llm_dict = {
        # "model_type" added for consistency with the scorer dict and the
        # text-bison branch above.
        "model_type": optimizer_llm_name.lower(),
        "max_decode_steps": optimizer_gpt_max_decode_steps,
        "temperature": optimizer_gpt_temperature,
        "batch_size": 1,
        "num_decodes": 1,
    }

    call_optimizer_server_func = functools.partial(
        prompt_utils.call_openai_server_func,  # local vLLM-backed call function
        max_decode_steps=optimizer_gpt_max_decode_steps,
        temperature=optimizer_gpt_temperature,
        local_model_path=local_model_path,
    )

  else:
    raise ValueError(f"Unsupported optimizer_llm_name: {optimizer_llm_name}")

  # ====================== try calling the servers ============================
  print("\n======== testing the scorer and optimizer servers ===========")
  scorer_test_output = call_scorer_server_func(
      "Does the sun rise from the north? Just answer yes or no."
  )
  print(f"number of scorer output decodes: {len(scorer_test_output)}")
  print(f"scorer test output: {scorer_test_output}")
  optimizer_test_output = call_optimizer_server_func(
      "Does the sun rise from the north? Just answer yes or no.",
      temperature=1.0,
  )
  print(f"number of optimizer output decodes: {len(optimizer_test_output)}")
  print(f"optimizer test output: {optimizer_test_output}")
  print("Finished testing the servers.")

  # ====================== read data ============================
  print("\n================ prompt optimization settings ==============")
  # from https://github.com/hendrycks/test/blob/master/categories.py
  subcategories = {
      "abstract_algebra": ["math"],
      "anatomy": ["health"],
      "astronomy": ["physics"],
      "business_ethics": ["business"],
      "clinical_knowledge": ["health"],
      "college_biology": ["biology"],
      "college_chemistry": ["chemistry"],
      "college_computer_science": ["computer science"],
      "college_mathematics": ["math"],
      "college_medicine": ["health"],
      "college_physics": ["physics"],
      "computer_security": ["computer science"],
      "conceptual_physics": ["physics"],
      "econometrics": ["economics"],
      "electrical_engineering": ["engineering"],
      "elementary_mathematics": ["math"],
      "formal_logic": ["philosophy"],
      "global_facts": ["other"],
      "high_school_biology": ["biology"],
      "high_school_chemistry": ["chemistry"],
      "high_school_computer_science": ["computer science"],
      "high_school_european_history": ["history"],
      "high_school_geography": ["geography"],
      "high_school_government_and_politics": ["politics"],
      "high_school_macroeconomics": ["economics"],
      "high_school_mathematics": ["math"],
      "high_school_microeconomics": ["economics"],
      "high_school_physics": ["physics"],
      "high_school_psychology": ["psychology"],
      "high_school_statistics": ["math"],
      "high_school_us_history": ["history"],
      "high_school_world_history": ["history"],
      "human_aging": ["health"],
      "human_sexuality": ["culture"],
      "international_law": ["law"],
      "jurisprudence": ["law"],
      "logical_fallacies": ["philosophy"],
      "machine_learning": ["computer science"],
      "management": ["business"],
      "marketing": ["business"],
      "medical_genetics": ["health"],
      "miscellaneous": ["other"],
      "moral_disputes": ["philosophy"],
      "moral_scenarios": ["philosophy"],
      "nutrition": ["health"],
      "philosophy": ["philosophy"],
      "prehistory": ["history"],
      "professional_accounting": ["other"],
      "professional_law": ["law"],
      "professional_medicine": ["health"],
      "professional_psychology": ["psychology"],
      "public_relations": ["politics"],
      "security_studies": ["politics"],
      "sociology": ["culture"],
      "us_foreign_policy": ["politics"],
      "virology": ["health"],
      "world_religions": ["philosophy"],
  }

  categories = {
      "STEM": [
          "physics",
          "chemistry",
          "biology",
          "computer science",
          "math",
          "engineering",
      ],
      "humanities": ["history", "philosophy", "law"],
      "social sciences": [
          "politics",
          "culture",
          "economics",
          "geography",
          "psychology",
      ],
      "other (business, health, misc.)": ["other", "business", "health"],
  }

  if dataset_name == "mmlu":
    # The MMLU "task" is a category; expand it into the matching CSV tasks.
    category_names = [task_name]
    folder_name = "test"  # one of {'auxiliary_train', 'dev', 'val', 'test'}
    task_names = []
    for task_csv_name in os.listdir(
        os.path.join(root_data_folder_path, folder_name)
    ):
      task_names.append(task_csv_name.split(".")[0])

    tasks_in_category = []
    for category_name in category_names:
      for task_name in task_names:
        for subname in subcategories:
          if subname in task_name:
            if subcategories[subname][0] in categories[category_name]:
              tasks_in_category.append(task_name)
              break

    tasks_all = [(folder_name, task_name) for task_name in tasks_in_category]
    multiple_choice_tasks = set([item[1] for item in tasks_all])
    boolean_tasks = set()
    numerical_output_tasks = set()

  elif dataset_name == "bbh":
    tasks_all = [task_name]
    assert (
        len(tasks_all) == 1
    ), "for now only support prompt optimization on one BBH task"

    numerical_output_tasks = {
        "object_counting",
        "multistep_arithmetic_two",
    }

    multiple_choice_tasks = {
        "date_understanding",
        "disambiguation_qa",
        "geometric_shapes",
        "hyperbaton",
        "logical_deduction_five_objects",
        "logical_deduction_seven_objects",
        "logical_deduction_three_objects",
        "movie_recommendation",
        "penguins_in_a_table",
        "reasoning_about_colored_objects",
        "ruin_names",
        "salient_translation_error_detection",
        "snarks",
        "temporal_sequences",
        "tracking_shuffled_objects_five_objects",
        "tracking_shuffled_objects_seven_objects",
        "tracking_shuffled_objects_three_objects",
    }

    boolean_tasks = {
        "boolean_expressions",  # True or False
        "causal_judgement",  # yes or no
        "formal_fallacies",  # valid or invalid
        "navigate",  # yes or no
        "sports_understanding",  # yes or no
        "web_of_lies",  # yes or no
    }

  else:
    assert dataset_name in {"gsm8k"}
    tasks_all = [task_name]
    multiple_choice_tasks = set()
    boolean_tasks = set()
    numerical_output_tasks = set(tasks_all)

  # How the scorer's raw prediction should be parsed for this dataset.
  if dataset_name == "mmlu":
    raw_data = pd.DataFrame()
    prediction_treat_as_number = False
    prediction_treat_as_bool = False
  elif dataset_name == "bbh":
    raw_data = []
    prediction_treat_as_number = bool(
        tasks_all[0] in numerical_output_tasks
    )  # for now only check the first task
    prediction_treat_as_bool = bool(
        tasks_all[0] in boolean_tasks
    )  # for now only check the first task
    print(
        f"prediction_treat_as_number: {prediction_treat_as_number},"
        f" prediction_treat_as_bool: {prediction_treat_as_bool}"
    )
  else:
    assert dataset_name == "gsm8k"
    raw_data = pd.DataFrame()
    prediction_treat_as_number = True
    prediction_treat_as_bool = False

  # Load every selected task into `raw_data` (DataFrame for MMLU/GSM8K,
  # list for BBH).
  for t in tasks_all:
    if dataset_name == "mmlu":
      folder_name = t[0]
      task_name = t[1]
      single_task_df = pd.read_csv(
          os.path.join(root_data_folder_path, f"{folder_name}/{task_name}.csv"),
          index_col=None,
          header=None,
      )
      raw_data = pd.concat([raw_data, single_task_df])
    elif dataset_name == "bbh":
      task_name = t
      single_task_list = opt_utils.load_bbh_task_data(
          task_name, base_dir=root_data_folder_path
      )
      raw_data += single_task_list
    else:
      assert dataset_name == "gsm8k"
      task_name = t
      f_gsm = os.path.join(root_data_folder_path, f"gsm_{task_name}.tsv")
      single_task_df = pd.read_csv(f_gsm, sep="\t", header=None)
      raw_data = pd.concat([raw_data, single_task_df])

  if dataset_name == "mmlu":
    num_examples = raw_data.shape[0]
  elif dataset_name == "bbh":
    num_examples = len(raw_data)
  else:
    assert dataset_name in {"gsm8k"}
    num_examples = raw_data.shape[0]
  print(f"number of examples in the current task: {num_examples}")

  # ================ split data into train/val/test ==========================
  if dataset_name == "mmlu":
    train_ratio = 0.8
    eval_ratio = 0.2
  elif dataset_name == "gsm8k":
    # train_ratio = 0.035
    train_ratio = 0.01  # originally 0.035; reduced to 0.01 (~74 examples)
    eval_ratio = 0
  else:
    assert dataset_name == "bbh"
    train_ratio = 0.2
    eval_ratio = 0

  assert train_ratio + eval_ratio <= 1
  test_ratio = 1 - train_ratio - eval_ratio
  print(
      f"train_ratio: {train_ratio}, eval_ratio: {eval_ratio}, "
      f"test_ratio: {test_ratio}"
  )
  # Fixed seed so the train/eval/test split is reproducible across runs.
  np.random.seed(0)
  train_index = np.sort(
      np.array(
          np.random.choice(
              num_examples, size=int(train_ratio * num_examples), replace=False
          )
      )
  )
  eval_and_test_index = np.sort(
      np.array(list(set(np.arange(num_examples)) - set(train_index)))
  )
  eval_index = np.sort(
      np.array(
          np.random.choice(
              eval_and_test_index,
              size=int(eval_ratio * num_examples),
              replace=False,
          )
      )
  )

  # ========== set other optimization experiment hyperparameters ==============
  # Instructions scoring below this threshold are dropped from the meta-prompt.
  if scorer_llm_name == "text-bison":
    old_instruction_score_threshold = 0.0
    # old_instruction_score_threshold = 0.15  # for GSM8K
  else:
    assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4", "local"}
    old_instruction_score_threshold = 0.3

  if scorer_llm_name == "text-bison":
    extract_final_answer_by_prompting_again = False
    include_qa = False
    evaluate_in_parallel = False
  else:
    assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4", "local"}
    extract_final_answer_by_prompting_again = False
    include_qa = False
    evaluate_in_parallel = False

  optimizer_llm_temperature = optimizer_llm_dict["temperature"]

  # num_few_shot_questions_for_instruction_refinement = 3
  num_few_shot_questions_for_instruction_refinement = 1  # fewer few-shot exemplars

  # num_generated_instructions_in_each_step = 8
  num_generated_instructions_in_each_step = 2  # only 2 instructions per step

  # num_search_steps = 200
  num_search_steps = 3  # originally 200; 3 steps are enough for a quick run

  initial_instructions = [
      "Let's solve the problem.",
      # "",
      # "The answer is",
  ]
  few_shot_qa_pairs = True
  # one of {'accumulative_most_frequent', 'current_most_frequent', 'random',
  # 'constant'}
  few_shot_selection_criteria = "random"
  # whether to evaluate generated instructions on the exemplars in meta-prompt
  evaluate_generated_ins_on_few_shot = False
  # whether to evaluate old instructions on the exemplars in the meta-prompt
  evaluate_old_ins_on_few_shot = False
  # every this number of steps, compute the accuracies of current-step
  # instructions on the validation set
  # eval_interval = 3
  eval_interval = 1  # evaluate every step to see results promptly
  # eval_interval = 10
  max_num_instructions = (
      20  # the maximum number of instructions and scores in the meta-prompt
  )
  # The number of buckets when converting scores to integers in the meta-prompt.
  num_score_buckets = 100
  # whether to put old instructions and scores to before exemplars in
  # the meta-prompt
  meta_prompt_instructions_before_exemplars = True

  # ===================== run prompt optimization ======================
  assert few_shot_selection_criteria in {
      "accumulative_most_frequent",
      "current_most_frequent",
      "random",
      "constant",
  }
  evolution_kwargs = {
      "num_search_steps": num_search_steps,
      "old_instruction_score_threshold": old_instruction_score_threshold,
      "scorer_llm_dict": scorer_llm_dict,
      "optimizer_llm_dict": optimizer_llm_dict,
      "extract_final_answer_by_prompting_again": (
          extract_final_answer_by_prompting_again
      ),
      "include_qa": include_qa,
      "evaluate_in_parallel": evaluate_in_parallel,
      "tasks_all": tasks_all,
      "train_ratio": train_ratio,
      "eval_ratio": eval_ratio,
      "test_ratio": test_ratio,
      "train_index": train_index,
      "eval_index": eval_index,
      "dataset_name": dataset_name,
      "task_name": task_name,
      "num_examples": num_examples,
      "root_data_folder_path": root_data_folder_path,
      "optimizer_llm_temperature": optimizer_llm_temperature,
      # "optimizer_llm_temperature_schedule": (
      #     optimizer_llm_temperature_schedule
      # ),
      # "optimizer_llm_temperature_end": optimizer_llm_temperature_end,
      "initial_instructions": initial_instructions,
      "multiple_choice_tasks": multiple_choice_tasks,
      "raw_data": raw_data,
      "call_scorer_server_func": call_scorer_server_func,
      "call_optimizer_server_func": call_optimizer_server_func,
      "instruction_pos": instruction_pos,
      "prediction_treat_as_number": prediction_treat_as_number,
      "prediction_treat_as_bool": prediction_treat_as_bool,
      "result_by_instruction_folder": result_by_instruction_folder,
      "few_shot_qa_pairs": few_shot_qa_pairs,
      "num_score_buckets": num_score_buckets,
      "max_num_instructions": max_num_instructions,
      "meta_prompt_type": meta_prompt_type,
      "meta_prompt_instructions_before_exemplars": (
          meta_prompt_instructions_before_exemplars
      ),
      "few_shot_selection_criteria": few_shot_selection_criteria,
      "optimizer_llm_name": optimizer_llm_name,
      "num_generated_instructions_in_each_step": (
          num_generated_instructions_in_each_step
      ),
      "evaluate_generated_ins_on_few_shot": evaluate_generated_ins_on_few_shot,
      "num_few_shot_questions_for_instruction_refinement": (
          num_few_shot_questions_for_instruction_refinement
      ),
      "evaluate_old_ins_on_few_shot": evaluate_old_ins_on_few_shot,
      "eval_interval": eval_interval,
      "save_folder": save_folder,
  }
  print("=== 开始优化过程 ===")
  try:
    opt_utils.run_evolution(**evolution_kwargs)
    print("=== 优化完成 ===")
  except Exception as e:
    import traceback
    print(f"!!! 优化失败: {e} !!!", file=sys.stderr)
    traceback.print_exc()
    sys.exit(1)
|
||||
|
||||
|
||||
# Script entry point: absl parses the flags defined above and calls main().
if __name__ == "__main__":
  app.run(main)
|
||||
|
||||
Reference in New Issue
Block a user