原始代码
This commit is contained in:
220
prompt_utils.py
Normal file
220
prompt_utils.py
Normal file
@@ -0,0 +1,220 @@
|
||||
# Copyright 2023 The OPRO Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The utility functions for prompting GPT and Google Cloud models."""
|
||||
import openai
|
||||
import time
|
||||
try:
|
||||
import google.generativeai as palm
|
||||
except Exception:
|
||||
palm = None
|
||||
try:
|
||||
from vllm import LLM, SamplingParams
|
||||
except Exception:
|
||||
LLM = None
|
||||
SamplingParams = None
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# 缓存 vLLM 实例,避免重复加载
|
||||
_llm_instance = None
|
||||
|
||||
def get_llm(local_model_path):
|
||||
if LLM is None:
|
||||
raise RuntimeError("vLLM not available")
|
||||
global _llm_instance
|
||||
if _llm_instance is None:
|
||||
assert local_model_path is not None, "model_path cannot be None"
|
||||
local_model_path = str(Path(local_model_path).resolve())
|
||||
_llm_instance = LLM(
|
||||
model=local_model_path,
|
||||
dtype="bfloat16",
|
||||
tensor_parallel_size=8,
|
||||
max_num_batched_tokens=8192,
|
||||
max_num_seqs=64,
|
||||
gpu_memory_utilization=0.7,
|
||||
enforce_eager=True,
|
||||
block_size=16,
|
||||
enable_chunked_prefill=True,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
return _llm_instance
|
||||
|
||||
def call_local_server_single_prompt(prompt, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs):
|
||||
"""
|
||||
使用本地 vLLM 模型生成单个 prompt 的响应,替代原本 OpenAI API。
|
||||
"""
|
||||
llm = get_llm(local_model_path)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=temperature,
|
||||
top_p=0.9,
|
||||
max_tokens=max_decode_steps,
|
||||
skip_special_tokens=True # 避免特殊字符触发协议错误
|
||||
)
|
||||
outputs = llm.generate([prompt], sampling_params)
|
||||
return outputs[0].outputs[0].text
|
||||
|
||||
def call_local_server_func(inputs, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs):
|
||||
"""
|
||||
批量处理多个输入 prompt。
|
||||
"""
|
||||
|
||||
assert local_model_path is not None, "local_model_path must be provided"
|
||||
# 强制类型检查
|
||||
if isinstance(inputs, bytes):
|
||||
inputs = inputs.decode('utf-8')
|
||||
outputs = []
|
||||
for input_str in inputs:
|
||||
output = call_local_server_single_prompt(
|
||||
input_str,
|
||||
local_model_path=local_model_path,
|
||||
temperature=temperature,
|
||||
max_decode_steps=max_decode_steps
|
||||
)
|
||||
outputs.append(output)
|
||||
return outputs
|
||||
|
||||
|
||||
def call_openai_server_single_prompt(
|
||||
prompt, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8
|
||||
):
|
||||
"""The function to call OpenAI server with an input string."""
|
||||
try:
|
||||
completion = openai.ChatCompletion.create(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_decode_steps,
|
||||
messages=[
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
)
|
||||
return completion.choices[0].message.content
|
||||
# 函数捕获了 6 类常见异常,并在遇到错误时自动重试:
|
||||
except openai.error.Timeout as e: # API 请求超时
|
||||
retry_time = e.retry_after if hasattr(e, "retry_after") else 30
|
||||
print(f"Timeout error occurred. Retrying in {retry_time} seconds...")
|
||||
time.sleep(retry_time)
|
||||
return call_openai_server_single_prompt(
|
||||
prompt, max_decode_steps=max_decode_steps, temperature=temperature
|
||||
)
|
||||
|
||||
except openai.error.RateLimitError as e: #请求频率超限(Rate Limit)
|
||||
retry_time = e.retry_after if hasattr(e, "retry_after") else 30
|
||||
print(f"Rate limit exceeded. Retrying in {retry_time} seconds...")
|
||||
time.sleep(retry_time)
|
||||
return call_openai_server_single_prompt(
|
||||
prompt, max_decode_steps=max_decode_steps, temperature=temperature
|
||||
)
|
||||
|
||||
except openai.error.APIError as e: # API 错误(如服务器错误)
|
||||
retry_time = e.retry_after if hasattr(e, "retry_after") else 30
|
||||
print(f"API error occurred. Retrying in {retry_time} seconds...")
|
||||
time.sleep(retry_time)
|
||||
return call_openai_server_single_prompt(
|
||||
prompt, max_decode_steps=max_decode_steps, temperature=temperature
|
||||
)
|
||||
|
||||
except openai.error.APIConnectionError as e: # API 连接错误(如网络问题)
|
||||
retry_time = e.retry_after if hasattr(e, "retry_after") else 30
|
||||
print(f"API connection error occurred. Retrying in {retry_time} seconds...")
|
||||
time.sleep(retry_time)
|
||||
return call_openai_server_single_prompt(
|
||||
prompt, max_decode_steps=max_decode_steps, temperature=temperature
|
||||
)
|
||||
|
||||
except openai.error.ServiceUnavailableError as e: # 服务不可用(如服务器维护)
|
||||
retry_time = e.retry_after if hasattr(e, "retry_after") else 30
|
||||
print(f"Service unavailable. Retrying in {retry_time} seconds...")
|
||||
time.sleep(retry_time)
|
||||
return call_openai_server_single_prompt(
|
||||
prompt, max_decode_steps=max_decode_steps, temperature=temperature
|
||||
)
|
||||
|
||||
except OSError as e: # 操作系统级连接错误(如网络中断)
|
||||
retry_time = 5 # Adjust the retry time as needed
|
||||
print(
|
||||
f"Connection error occurred: {e}. Retrying in {retry_time} seconds..."
|
||||
)
|
||||
time.sleep(retry_time)
|
||||
return call_openai_server_single_prompt(
|
||||
prompt, max_decode_steps=max_decode_steps, temperature=temperature
|
||||
)
|
||||
|
||||
|
||||
def call_openai_server_func( #批量处理多个输入提示(prompts),通过 OpenAI API 并行或顺序获取多个生成结果。
|
||||
inputs, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8
|
||||
):
|
||||
"""The function to call OpenAI server with a list of input strings."""
|
||||
if isinstance(inputs, str): # 将单个字符串转为列表,统一处理
|
||||
inputs = [inputs]
|
||||
outputs = []
|
||||
for input_str in inputs:
|
||||
output = call_openai_server_single_prompt(
|
||||
input_str,
|
||||
model=model,
|
||||
max_decode_steps=max_decode_steps,
|
||||
temperature=temperature,
|
||||
)
|
||||
outputs.append(output)
|
||||
return outputs
|
||||
|
||||
#通过 Google PaLM API(Cloud 版)调用 text-bison模型生成文本,并包含基本错误处理和自动重试机制。
|
||||
def call_palm_server_from_cloud(
|
||||
input_text, model="text-bison-001", max_decode_steps=20, temperature=0.8
|
||||
):
|
||||
if palm is None:
|
||||
raise RuntimeError("google.generativeai not available")
|
||||
assert isinstance(input_text, str)
|
||||
assert model == "text-bison-001"
|
||||
all_model_names = [
|
||||
m
|
||||
for m in palm.list_models()
|
||||
if "generateText" in m.supported_generation_methods
|
||||
]
|
||||
model_name = all_model_names[0].name
|
||||
try:
|
||||
completion = palm.generate_text(
|
||||
model=model_name,
|
||||
prompt=input_text,
|
||||
temperature=temperature,
|
||||
max_output_tokens=max_decode_steps,
|
||||
)
|
||||
output_text = completion.result
|
||||
return [output_text]
|
||||
except Exception:
|
||||
retry_time = 10
|
||||
time.sleep(retry_time)
|
||||
return call_palm_server_from_cloud(
|
||||
input_text, max_decode_steps=max_decode_steps, temperature=temperature
|
||||
)
|
||||
|
||||
def refine_instruction(query: str) -> str:
|
||||
return f"""
|
||||
你是一个“问题澄清与重写助手”。
|
||||
请根据用户的原始问题:
|
||||
【{query}】
|
||||
生成不少于20条多角度、可直接执行的问题改写,每行一条。
|
||||
"""
|
||||
|
||||
def refine_instruction_with_history(query: str, rejected_list: list) -> str:
|
||||
rejected_text = "\n".join(f"- {r}" for r in rejected_list) if rejected_list else ""
|
||||
return f"""
|
||||
你是一个“问题澄清与重写助手”。
|
||||
原始问题:
|
||||
{query}
|
||||
|
||||
以下改写已被否定:
|
||||
{rejected_text}
|
||||
|
||||
请从新的角度重新生成至少20条不同的改写问题,每条单独一行。
|
||||
"""
|
||||
Reference in New Issue
Block a user