Compare commits
10 Commits
dd5339de32 ... 0b5319b31c

| Author | SHA1 | Date |
|---|---|---|
| | 0b5319b31c | |
| | 6426b73a5e | |
| | 26f8e0c648 | |
| | 65cdcf29dc | |
| | 602875b08c | |
| | da30a0999c | |
| | 1376d60ed5 | |
| | 8f52fad41c | |
| | b4934dfe6d | |
| | 800bed638f | |
28 .dockerignore Normal file
@@ -0,0 +1,28 @@
__pycache__
*.pyc
*.pyo
*.pyd
.Python
*.so
*.egg
*.egg-info
dist
build
.git
.gitignore
.vscode
.idea
*.md
!README.md
# Include pre-downloaded Ollama binary for offline build
!ollama-linux-amd64.tgz
local_docs
examples
outputs
.DS_Store
*.log
.env
.venv
venv
env
171 .gitignore vendored Normal file
@@ -0,0 +1,171 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
Pipfile.lock

# poetry
poetry.lock

# pdm
.pdm.toml

# PEP 582
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# IDEs
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store

# Project specific
outputs/
*.jsonl
*.log
local_docs/

# Docker build artifacts (DO NOT commit these - they are huge!)
ollama-models/
*.tar
ollama-linux-amd64.tgz
system-prompt-optimizer-*.tar
*.tar.gz

# Backup files from scripts
*.bak

# Node modules (if any frontend dependencies)
node_modules/
package-lock.json
yarn.lock

# Temporary files
*.tmp
*.temp
.trae/
@@ -1,37 +0,0 @@
## Problem Analysis

- The browser reports `net::ERR_ABORTED http://0.0.0.0:8010/`, which is common when the preview environment aborts access to `0.0.0.0` or the redirect flow never completes.
- The backend is running and mounts the static directory at `/`, but the previewer may be sensitive to loading the root path, or host/port resolution is inconsistent.
- API routes and the static mount currently coexist, and relative `fetch('/query')` calls work; the problem is mainly loading the root page.

## Fix Plan

1. Host and port access
   - Prefer `http://127.0.0.1:8010/` or `http://localhost:8010/` instead of `0.0.0.0`.
   - Add a `/health` route for quickly checking whether the service is running.

2. Static page mount point
   - Move the static directory mount from the root `/` to `/ui` to reduce potential conflicts with the root path.
   - Add a `GET /` route that returns `frontend/index.html` or redirects (302) to `/ui/index.html`.

3. Frontend requests and error handling
   - Keep the relative paths `fetch('/query')`, `/select`, `/reject` so requests stay same-origin; improve error messages (show response status and body).
   - If cross-origin access is needed (frontend deployed separately), add CORS so the frontend origin can reach the backend API (a sketch follows after this list).

4. Diagnostics and verification
   - Use `curl http://127.0.0.1:8010/health` to verify health.
   - Use `curl` end to end: `/query` (new session and regeneration with `session_id`), `/select` (generate the answer).
   - Open the `/ui/` page in a browser and run the full flow: start generation → reject and regenerate → select and get an answer.
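
For the optional CORS step, a minimal sketch (assumptions: the existing FastAPI `app` from `_qwen_xinference_demo/api.py` would be reused, and the allowed origin below is a placeholder, not a value from this repo):

```python
# Hedged sketch: enable CORS only if the frontend is served from a different origin.
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()  # in this project, the existing app from _qwen_xinference_demo/api.py would be used

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://127.0.0.1:3000"],  # hypothetical frontend origin; adjust to the real deployment
    allow_methods=["*"],
    allow_headers=["*"],
)
```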

## Concrete Changes

- `_qwen_xinference_demo/api.py`
  - Add a `GET /health` route returning `{status:"ok"}`.
  - Mount `StaticFiles(directory="frontend", html=True)` at `/ui` instead of `/`.
  - Add a `GET /` route that returns `index.html` or redirects to `/ui/index.html`.
- `frontend/index.html`
  - Improve error display: show both the response status code and the response text (better diagnostics).

## Follow-up Enhancements (optional)

- Add latency and source logging for `/query` and `/select` to simplify troubleshooting.
- Show the history of candidates and rejection reasons on the page for better observability.
- Add a config option to switch embedding priority (Xinference/Ollama).

Please confirm whether to proceed with the plan above; implementation and end-to-end testing will follow immediately.
461 DEPLOYMENT.md Normal file
@@ -0,0 +1,461 @@
# Docker Deployment Guide

This document describes how to deploy the system prompt optimizer on a server without internet access.

## Deployment Options

This project offers two deployment options:

### Option A: All-in-One image (recommended for servers without internet access)

**Pros**:
- Contains all dependencies: application code + Ollama + LLM models
- A single image file, simple to deploy
- No extra software needs to be installed on the target server (apart from Docker)

**Cons**:
- The image file is large (10-20GB)
- Transfer takes a while

### Option B: Split deployment (for servers that already run Ollama)

**Pros**:
- Smaller image file (~500MB)
- Can reuse an existing Ollama service

**Cons**:
- Ollama must be installed and configured separately on the target server
- Models must be downloaded manually

---

## Option A: All-in-One Deployment (recommended)

### Prerequisites

#### On the development machine (with internet access)

1. **Docker** installed
2. **Ollama** installed and running
3. **Disk space**: at least 30GB free
4. The required Ollama models already downloaded:
   - `qwen3:14b` (main model, ~8GB)
   - `qwen3-embedding:4b` (embedding model, ~2GB)

#### On the target server (without internet access)

1. **Docker** installed
2. **Disk space**: at least 25GB free

### Deployment Steps

#### Step 1: Download the required Ollama models

On the development machine, make sure the required models are downloaded:

```bash
# Download the main model (about 8GB)
ollama pull qwen3:14b

# Download the embedding model (about 2GB)
ollama pull qwen3-embedding:4b

# Verify the models are downloaded
ollama list
```

#### Step 2: Export the Ollama models

```bash
# Run the export script
./export-ollama-models.sh
```

This creates an `ollama-models/` directory containing all model files.

#### Step 3: Build the All-in-One Docker image

```bash
# Run the build script (recommended)
./build-allinone.sh

# Or build manually
docker build -f Dockerfile.allinone -t system-prompt-optimizer:allinone .
```

**Note**: the build can take 10-30 minutes depending on machine performance.

#### Step 4: Export the Docker image

If you used `build-allinone.sh`, the image has already been exported. Otherwise, export it manually:

```bash
# Export the image (about 10-20GB)
docker save -o system-prompt-optimizer-allinone.tar system-prompt-optimizer:allinone

# Verify the file size
ls -lh system-prompt-optimizer-allinone.tar
```

#### Step 5: Transfer to the target server

Transfer the image file via scp, a USB drive, or any other means:

```bash
# Use scp (if the network is reachable)
scp system-prompt-optimizer-allinone.tar user@server:/path/

# Or use rsync (supports resuming)
rsync -avP --progress system-prompt-optimizer-allinone.tar user@server:/path/

# Or move it physically on a USB drive / external disk
```

#### Step 6: Load the image on the target server

```bash
# Load the image (takes a few minutes)
docker load -i system-prompt-optimizer-allinone.tar

# If you hit a permission error, use sudo
# sudo docker load -i system-prompt-optimizer-allinone.tar

# Verify the image is loaded
docker images | grep system-prompt-optimizer
```

#### Step 7: Start the service

**CPU mode (default):**

```bash
# Start the container (recommended: expose only the web port)
docker run -d \
  --name system-prompt-optimizer \
  -p 8010:8010 \
  --restart unless-stopped \
  system-prompt-optimizer:allinone

# Follow the startup logs
docker logs -f system-prompt-optimizer
```

**GPU mode (recommended if an NVIDIA GPU is available):**

```bash
# Use all available GPUs (recommended)
docker run -d \
  --name system-prompt-optimizer \
  --gpus all \
  -p 8010:8010 \
  --restart unless-stopped \
  system-prompt-optimizer:allinone

# Or pin a specific GPU
docker run -d \
  --name system-prompt-optimizer \
  --gpus '"device=0"' \
  -p 8010:8010 \
  --restart unless-stopped \
  system-prompt-optimizer:allinone

# Follow the startup logs
docker logs -f system-prompt-optimizer
```

**GPU deployment prerequisites**:
- NVIDIA driver installed (`nvidia-smi` works)
- NVIDIA Container Toolkit installed
- GPU memory ≥ 10GB (14b model) or ≥ 6GB (8b model)

**Detailed GPU deployment guide**: see [GPU_DEPLOYMENT.md](GPU_DEPLOYMENT.md)

**Important**:
- The first start takes 30-60 seconds (CPU) or 10-20 seconds (GPU) while the Ollama service initializes
- GPU mode speeds up inference by roughly 5-10x
- Port 11434 (Ollama) is optional and only needs to be exposed when Ollama must be reachable from outside
- Not exposing 11434 is safer, because the Ollama API has no authentication

#### Step 8: Verify the deployment

```bash
# Wait for the service to start (about 30-60 seconds)
sleep 60

# Health check
curl http://localhost:8010/health

# Expected response:
# {"status":"ok","version":"0.1.0"}

# Check the Ollama service
curl http://localhost:11434/api/tags

# Check available models
curl http://localhost:8010/models

# Open the web UI
# In a browser: http://<server IP>:8010/ui/opro.html
```
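
Instead of a fixed `sleep 60`, the check can poll `/health` until the service responds. A minimal Python sketch (it only assumes the container is reachable on `localhost:8010` and that the `requests` package is installed):

```python
# wait_for_health.py - poll the /health endpoint until it answers or a timeout expires.
import time
import requests

def wait_for_health(url: str = "http://localhost:8010/health", timeout: int = 180) -> bool:
    """Return True once /health responds with HTTP 200, False if the timeout expires."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            resp = requests.get(url, timeout=5)
            if resp.status_code == 200:
                print("healthy:", resp.json())
                return True
        except requests.RequestException:
            pass  # service still starting; try again
        time.sleep(3)
    return False

if __name__ == "__main__":
    raise SystemExit(0 if wait_for_health() else 1)
```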

---

## Option B: Split Deployment

### Prerequisites

#### On the target server

1. **Docker** installed
2. **Ollama** installed and running
3. The required Ollama models pulled:
   - `qwen3:14b` (main model)
   - `qwen3-embedding:4b` (embedding model)

### Deployment Steps

#### Step 1: Build the application image

```bash
# Build on the development machine
docker build -t system-prompt-optimizer:latest .

# Export the image
docker save -o system-prompt-optimizer.tar system-prompt-optimizer:latest
```

#### Step 2: Transfer and load

```bash
# Transfer to the target server
scp system-prompt-optimizer.tar user@server:/path/

# Load on the target server
docker load -i system-prompt-optimizer.tar
```

#### Step 3: Start the service

```bash
# Using Docker Compose
docker-compose up -d

# Or with docker commands
docker run -d \
  --name system-prompt-optimizer \
  -p 8010:8010 \
  -e OLLAMA_HOST=http://host.docker.internal:11434 \
  -v $(pwd)/outputs:/app/outputs \
  --add-host host.docker.internal:host-gateway \
  --restart unless-stopped \
  system-prompt-optimizer:latest
```

## Configuration

### Environment variables

The following environment variables can be set in `docker-compose.yml` or in the `docker run` command:

- `OLLAMA_HOST`: Ollama service address (default: `http://host.docker.internal:11434`)
- `PYTHONUNBUFFERED`: Python output buffering (default: `1`)

### Port mappings

- **8010**: Web UI and API port (required)
- **11434**: Ollama API port (optional, only for debugging or external access to Ollama)

### Data persistence

- `./outputs`: directory for user feedback logs (mapped to `/app/outputs` inside the container)

## Troubleshooting

### 0. Docker daemon connection error

**Problem**: running `docker` commands reports "Cannot connect to the Docker daemon"

**Symptom**:
```
Cannot connect to the Docker daemon at unix:///var/run/docker.sock. Is the docker daemon running?
```

**Solutions**:

**Method 1: Check the Docker service status**
```bash
# Check whether Docker is running
sudo systemctl status docker

# If it is not running, start it
sudo systemctl start docker

# Enable it at boot
sudo systemctl enable docker
```

**Method 2: Add your user to the docker group (recommended)**
```bash
# Add the current user to the docker group
sudo usermod -aG docker $USER

# Apply the group change (re-login or use newgrp)
newgrp docker

# Or simply log out and back in

# Verify
docker info
```

**Method 3: Fix the Docker socket permissions**
```bash
# Check the socket permissions
ls -l /var/run/docker.sock

# Fix the permissions
sudo chown root:docker /var/run/docker.sock
sudo chmod 660 /var/run/docker.sock
```

**Method 4: Temporarily use sudo**
```bash
# If none of the above works, run Docker commands with sudo
sudo docker load -i system-prompt-optimizer-allinone.tar
sudo docker run -d --name system-prompt-optimizer ...
```

**Verify the fix**:
```bash
# Docker info should display normally
docker info

# The current user should be in the docker group
groups | grep docker
```

---

### 1. Cannot reach the Ollama service

**Problem**: the container cannot reach the Ollama service on the host

**Solution**:
```bash
# Make sure the --add-host flag is used
--add-host host.docker.internal:host-gateway

# Or point directly at the host IP
-e OLLAMA_HOST=http://192.168.1.100:11434
```

### 2. Models unavailable (All-in-One deployment)

**Problem**: the models inside the container did not load correctly

**Solution**:
```bash
# Enter the container
docker exec -it system-prompt-optimizer bash

# Check the models inside the container
ollama list

# If the models are missing, check the model directory
ls -la /root/.ollama/models/

# Leave the container
exit
```

If the models really are missing, the image may need to be rebuilt.

### 3. Models unavailable (split deployment)

**Problem**: the Ollama models are not installed

**Solution**:
```bash
# Pull the models on the host
ollama pull qwen3:14b
ollama pull qwen3-embedding:4b

# Verify the models are installed
ollama list
```

### 4. Container fails to start

**Problem**: port in use or permission issues

**Solution**:
```bash
# Check port usage
netstat -tulpn | grep 8010
netstat -tulpn | grep 11434

# Change ports (All-in-One needs two ports)
docker run -p 8011:8010 -p 11435:11434 ...

# Check the container logs
docker logs system-prompt-optimizer
```

### 5. Performance issues

**Problem**: generation is slow

**Solutions**:
- Make sure Ollama uses GPU acceleration
- Use a smaller model (e.g. `qwen3:4b`)
- Adjust `GENERATION_POOL_SIZE` in `config.py`

## Updating a Deployment

```bash
# 1. Rebuild the image on the development machine
docker build -t system-prompt-optimizer:latest .

# 2. Export the new image
docker save -o system-prompt-optimizer-new.tar system-prompt-optimizer:latest

# 3. Transfer to the server and load it
docker load -i system-prompt-optimizer-new.tar

# 4. Restart the service
docker-compose down
docker-compose up -d

# Or with docker commands
docker stop system-prompt-optimizer
docker rm system-prompt-optimizer
docker run -d ... # use the same run command as before
```

## Security Recommendations

1. **Network isolation**: if external access is not needed, bind to localhost only
   ```bash
   -p 127.0.0.1:8010:8010
   ```

2. **Firewall**: restrict access with firewall rules
   ```bash
   # Allow only a specific subnet
   iptables -A INPUT -p tcp --dport 8010 -s 192.168.1.0/24 -j ACCEPT
   iptables -A INPUT -p tcp --dport 8010 -j DROP
   ```

3. **Log management**: clean up log files regularly
   ```bash
   # Limit Docker log size
   docker run --log-opt max-size=10m --log-opt max-file=3 ...
   ```

## Support

If you run into problems, check:
- Application logs: `docker logs system-prompt-optimizer`
- Ollama logs: `journalctl -u ollama -f`
- API docs: http://localhost:8010/docs
38 Dockerfile Normal file
@@ -0,0 +1,38 @@
FROM python:3.10-slim

# Set working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements file
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY _qwen_xinference_demo/ ./_qwen_xinference_demo/
COPY frontend/ ./frontend/
COPY config.py .

# Create outputs directory
RUN mkdir -p outputs

# Expose port
EXPOSE 8010

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV OLLAMA_HOST=http://host.docker.internal:11434

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8010/health || exit 1

# Run the application
CMD ["uvicorn", "_qwen_xinference_demo.api:app", "--host", "0.0.0.0", "--port", "8010"]
58 Dockerfile.allinone Normal file
@@ -0,0 +1,58 @@
FROM --platform=linux/amd64 python:3.10-slim

# Set working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

# Install Ollama manually for amd64
# Copy pre-downloaded Ollama binary to avoid slow downloads during build
# Using v0.13.1 (latest stable as of Dec 2024)
COPY ollama-linux-amd64.tgz /tmp/ollama-linux-amd64.tgz
RUN tar -C /usr -xzf /tmp/ollama-linux-amd64.tgz \
    && rm /tmp/ollama-linux-amd64.tgz

# Copy requirements file
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY _qwen_xinference_demo/ ./_qwen_xinference_demo/
COPY frontend/ ./frontend/
COPY config.py .

# Create necessary directories
RUN mkdir -p outputs /root/.ollama

# Copy pre-downloaded Ollama models
# This includes qwen3:14b and qwen3-embedding:4b
COPY ollama-models/ /root/.ollama/

# Expose ports
EXPOSE 8010 11434

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV OLLAMA_HOST=http://localhost:11434
# Enable GPU support for Ollama (will auto-detect NVIDIA GPU if available)
ENV NVIDIA_VISIBLE_DEVICES=all
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility

# Copy startup script
COPY docker-entrypoint.sh /docker-entrypoint.sh
RUN chmod +x /docker-entrypoint.sh

# Health check
# Only check the web application, not Ollama (internal service)
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD curl -f http://localhost:8010/health || exit 1

# Run the startup script
ENTRYPOINT ["/docker-entrypoint.sh"]
117 QUICK_START.md Normal file
@@ -0,0 +1,117 @@
# Quick Start Guide

## Offline Deployment (All-in-One)

### On the development machine (with internet access)

```bash
# 1. Download the models
ollama pull qwen3:14b
ollama pull qwen3-embedding:4b

# 2. Export the models
./export-ollama-models.sh

# 3. Build and export the Docker image
./build-allinone.sh

# 4. Transfer to the target server
# File: system-prompt-optimizer-allinone.tar (about 10-20GB)
scp system-prompt-optimizer-allinone.tar user@server:/path/
```

### On the target server (without internet access)

```bash
# 1. Load the image
docker load -i system-prompt-optimizer-allinone.tar

# 2. Start the service
docker run -d \
  --name system-prompt-optimizer \
  -p 8010:8010 \
  -p 11434:11434 \
  -v $(pwd)/outputs:/app/outputs \
  --restart unless-stopped \
  system-prompt-optimizer:allinone

# 3. Wait for startup (about 60 seconds)
sleep 60

# 4. Verify
curl http://localhost:8010/health
curl http://localhost:11434/api/tags

# 5. Open the UI
# http://<server IP>:8010/ui/opro.html
```

## Common Commands

```bash
# View logs
docker logs -f system-prompt-optimizer

# Restart the service
docker restart system-prompt-optimizer

# Stop the service
docker stop system-prompt-optimizer

# Remove the container
docker rm -f system-prompt-optimizer

# Enter the container
docker exec -it system-prompt-optimizer bash

# Check the models
docker exec -it system-prompt-optimizer ollama list
```

## Ports

- **8010**: Web UI and API
- **11434**: Ollama service (only needs to be exposed in the All-in-One setup)

## Files

- `system-prompt-optimizer-allinone.tar`: full image (10-20GB)
- `outputs/`: user feedback log directory

## Troubleshooting

### Service fails to start

```bash
# View logs
docker logs system-prompt-optimizer

# Check port usage
netstat -tulpn | grep 8010
netstat -tulpn | grep 11434
```

### Models unavailable

```bash
# Check inside the container
docker exec -it system-prompt-optimizer ollama list

# You should see:
# qwen3:14b
# qwen3-embedding:4b
```

### Slow performance

- Make sure the server has enough RAM (16GB+ recommended)
- If a GPU is available, use a GPU-enabled Docker runtime
- Adjust `GENERATION_POOL_SIZE` in `config.py`

## More Information

For details see:
- `DEPLOYMENT.md`: full deployment guide
- `README.md`: project overview
- http://localhost:8010/docs: API docs
346 README.md
@@ -0,0 +1,346 @@
# System Prompt Generator

## Overview

A tool for generating and iteratively optimizing system prompts with a large language model. Given a short task description, it automatically produces high-quality system instructions and supports iterative refinement driven by the user's selections.

### Core features

- **Smart instruction generation**: automatically generates multiple high-quality system instruction candidates from a task description
- **Iterative refinement**: generates improved versions based on the instruction the user selected, while steering away from rejected directions
- **Role-definition format**: every generated instruction starts with a role definition (e.g. "你是一个..."), following best practice
- **Smart candidate selection**: semantic clustering and diversity-based selection pick the most representative instructions from a large candidate pool
- **Session management**: multiple tasks can be managed in parallel, with history
- **Full requirement coverage**: generated instructions cover all task requirements and details rather than just aiming for stylistic variety

### User interface

- **Modern chat interface**: clean design in the style of Google Gemini
- **Sidebar session management**: collapsible sidebar with multi-session switching
- **Real-time generation feedback**: each round produces 5 candidate instructions; the user can keep refining or copy one to use
- **Model selection**: different LLM models can be chosen from the UI

## Key Characteristics

### 1. A simple, intuitive workflow

Unlike the full OPRO algorithm (which needs test cases and automatic evaluation), this tool uses plain iterative refinement:

- **Initial generation**: enter a task description → get 5 comprehensive system instruction candidates
- **Iterative refinement**: pick the instruction you like → get improved versions based on it, while avoiding rejected directions
- **No scoring**: no test cases or performance scores are needed; refinement is driven entirely by user preference

### 2. High-quality instruction generation

- **Role-definition format**: every instruction starts with "你是一个...", following system prompt best practice
- **Full requirement coverage**: instructions cover all task requirements and details
- **Clear and actionable**: instructions are clear, specific, and executable, including the necessary behavioral rules and output format
- **Simplified Chinese**: all generated instructions are in Simplified Chinese

### 3. Performance optimizations

- **Tuned candidate pool size**: 10 candidates are generated, and clustering selects the 5 most diverse
- **Smart cluster-based selection**: AgglomerativeClustering picks the most representative instructions from the pool
- **Embedding service fallback**: automatic Xinference → Ollama fallback keeps the service available

### 4. API architecture

- **Core endpoints**:
  - `POST /opro/create` - create a new task
  - `POST /opro/generate_and_evaluate` - generate initial candidates
  - `POST /opro/refine` - iteratively refine based on the user's selection
  - `GET /opro/sessions` - list all sessions
  - `GET /opro/runs` - list all tasks
- **Session management**: multiple sessions and tasks can be managed in parallel
- **Backward compatible**: the original query-rewriting features are kept and tagged `opro-legacy`

### 5. Frontend

- **Gemini-style design**: clean white/grey palette, rounded corners, subtle shadows
- **Collapsible sidebar**: collapsed by default, with a session list
- **Multi-line input box**: supports multi-line text, with a model selector in the bottom toolbar
- **Candidate instruction cards**: each candidate shows its number and content, with "Refine" and "Copy" buttons
- **Simplified Chinese UI**: all UI text and generated instructions are in Simplified Chinese

## Quick Start

### Requirements

- **Python** ≥ 3.10 (a conda virtual environment is recommended)
- **Ollama** running locally with the required models (e.g. `qwen3:8b`, `qwen3-embedding:4b`)
- **Optional**: a Xinference embedding service

### Install dependencies

```bash
# Create a conda environment (recommended)
conda create -n opro python=3.10
conda activate opro

# Install Python dependencies
pip install fastapi uvicorn requests numpy scikit-learn pydantic
```

### Start the Ollama service

```bash
# Make sure Ollama is installed and running
ollama serve

# Pull the required models
ollama pull qwen3:8b
ollama pull qwen3-embedding:4b
```

### Start the application

```bash
# Start the backend service
uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010

# Or bind 0.0.0.0 to allow external access
uvicorn _qwen_xinference_demo.api:app --host 0.0.0.0 --port 8010
```

### Open the UI

- **System instruction generator**: http://127.0.0.1:8010/ui/opro.html
- **Legacy three-pane UI**: http://127.0.0.1:8010/ui/
- **API docs**: http://127.0.0.1:8010/docs
- **OpenAPI JSON**: http://127.0.0.1:8010/openapi.json

### Usage example

1. **Create a new session**: click "New session" or the + button in the sidebar
2. **Enter a task description**: for example "帮我写一个专业的营销文案生成助手"
3. **Review the candidates**: the system generates 5 comprehensive system instructions, each starting with a role definition
4. **Select and refine**: click "Refine" on the instruction you like to generate improved versions based on it
5. **Copy and use**: click "Copy" to put the instruction on the clipboard for use in your own application

## Configuration

Configuration file: `config.py`

### Key settings

```python
# Ollama service configuration
OLLAMA_HOST = "http://127.0.0.1:11434"
DEFAULT_CHAT_MODEL = "qwen3:8b"
DEFAULT_EMBED_MODEL = "qwen3-embedding:4b"

# Generation parameters
GENERATION_POOL_SIZE = 10  # Candidate pool size (generate 10, clustering selects 5)
TOP_K = 5  # Number of candidates returned to the user
CLUSTER_DISTANCE_THRESHOLD = 0.15  # Clustering distance threshold

# Xinference configuration (optional)
XINFERENCE_EMBED_URL = "http://127.0.0.1:9997/models/bge-base-zh/embed"
```

## Project Layout

```
.
├── _qwen_xinference_demo/
│   ├── api.py                        # FastAPI main application
│   └── opro/
│       ├── user_prompt_optimizer.py  # OPRO core logic
│       ├── prompt_utils.py           # Meta-prompt construction
│       ├── session_state.py          # Session state management
│       ├── ollama_client.py          # Ollama client
│       └── xinference_client.py      # Xinference client
├── frontend/
│   ├── opro.html                     # OPRO optimization UI
│   └── index.html                    # Legacy three-pane UI
├── examples/
│   ├── opro_demo.py                  # OPRO feature demo
│   └── client_demo.py                # API usage example
├── config.py                         # Global configuration
├── API.md                            # API documentation
└── README.md                         # This file
```

## API Endpoints

### Session management

- `POST /opro/session/create` - create a new session
- `GET /opro/sessions` - list all sessions
- `GET /opro/session/{session_id}` - get session details

### Task management

- `POST /opro/create` - create a new task within a session
  - Request body: `{"session_id": "xxx", "task_description": "任务描述", "model_name": "qwen3:8b"}`
  - Returns: `{"run_id": "xxx", "task_description": "...", "iteration": 0}`

### Instruction generation

- `POST /opro/generate_and_evaluate` - generate initial candidate instructions
  - Request body: `{"run_id": "xxx", "top_k": 5, "pool_size": 10}`
  - Returns: `{"candidates": [{"instruction": "...", "score": null}, ...]}`

- `POST /opro/refine` - iteratively refine based on the user's selection
  - Request body: `{"run_id": "xxx", "selected_instruction": "用户选择的指令", "rejected_instructions": ["被拒绝的指令1", "被拒绝的指令2"]}`
  - Returns: `{"candidates": [{"instruction": "...", "score": null}, ...], "iteration": 1}`
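
Putting the endpoints above together, a minimal Python client could look like the sketch below. This is an illustration, not the bundled `examples/client_demo.py`; it assumes a local server on `127.0.0.1:8010` and unwraps the `{"success": ..., "data": ...}` envelope described later in this README.

```python
# Sketch of the create -> generate -> refine flow against a locally running server.
import requests

BASE = "http://127.0.0.1:8010"

def post(path: str, payload: dict) -> dict:
    resp = requests.post(f"{BASE}{path}", json=payload, timeout=300)
    resp.raise_for_status()
    body = resp.json()
    return body.get("data", body)  # unwrap {"success": ..., "data": ...} if present

# 1. Create a task (model_name/session_id are optional; this description is just an example)
run = post("/opro/create", {"task_description": "帮我写一个专业的营销文案生成助手"})
run_id = run["run_id"]

# 2. Generate the first batch of candidate instructions
first = post("/opro/generate_and_evaluate", {"run_id": run_id, "top_k": 5, "pool_size": 10})
candidates = [c["instruction"] for c in first["candidates"]]

# 3. Keep the first candidate and refine, rejecting the others
refined = post("/opro/refine", {
    "run_id": run_id,
    "selected_instruction": candidates[0],
    "rejected_instructions": candidates[1:],
})
for item in refined["candidates"]:
    print(item["instruction"])
```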

### Task queries

- `GET /opro/runs` - list all tasks
- `GET /opro/run/{run_id}` - get task details

### Legacy endpoints (backward compatible)

- `POST /query` - query rewriting (first round)
- `POST /select` - select a candidate and answer
- `POST /reject` - reject and regenerate
- `POST /message` - chat message

### General endpoints

- `GET /health` - health check
- `GET /version` - version information
- `GET /models` - list available models
- `POST /set_model` - set the model

Full API documentation: http://127.0.0.1:8010/docs

## How It Works

### Initial generation flow

1. The user enters a task description (e.g. "帮我写一个专业的营销文案生成助手")
2. The system uses the LLM to generate 10 candidate instructions
3. Semantic embeddings and a clustering algorithm select the 5 most diverse candidates (see the sketch below)
4. All candidates start with a role definition and cover the full task requirements
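
The diversity-selection step can be pictured with a short sketch. This is not the project's `user_prompt_optimizer.py`, only an outline of the idea: embed each candidate (via the Xinference/Ollama embedding service), cluster by cosine distance with scikit-learn's `AgglomerativeClustering`, and keep one representative per cluster. The helper name and details are hypothetical; the real parameters live in `config.py` (`CLUSTER_DISTANCE_THRESHOLD`, `TOP_K`).

```python
# Sketch: pick at most top_k representative candidates from an embedded pool.
import numpy as np
from sklearn.cluster import AgglomerativeClustering

def select_diverse(candidates: list[str], embeddings: np.ndarray,
                   top_k: int = 5, distance_threshold: float = 0.15) -> list[str]:
    """Cluster candidates by cosine distance and keep the member closest to each cluster centroid."""
    if len(candidates) <= top_k:
        return candidates
    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=distance_threshold,
        metric="cosine",        # scikit-learn >= 1.2; older versions use affinity="cosine"
        linkage="average",
    )
    labels = clustering.fit_predict(embeddings)
    representatives = []
    for label in np.unique(labels):
        idx = np.where(labels == label)[0]
        centroid = embeddings[idx].mean(axis=0)
        best = idx[np.argmin(np.linalg.norm(embeddings[idx] - centroid, axis=1))]
        representatives.append(candidates[best])
    return representatives[:top_k]
```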

### Iterative refinement flow

1. The user selects the instruction they like (say candidate #3)
2. The system records the rejected instructions (candidates #1, #2, #4, #5)
3. The LLM is asked to generate improved versions based on the selected instruction while avoiding the directions of the rejected ones
4. 10 new candidates are generated and clustering selects 5 to return
5. The user can keep iterating or copy the result

### How this differs from OPRO

**OPRO (the original algorithm)**:
- Needs test cases (e.g. correct answers to math problems)
- Scores automatically (e.g. accuracy 0.73, 0.81)
- Optimizes along a performance trajectory
- Suited to tasks with a clear evaluation criterion

**This tool (simple iterative refinement)**:
- No test cases needed
- No automatic scoring
- Improves based on user preference
- Works for arbitrary general-purpose tasks

## FAQ

### 1. Cannot reach the Ollama service

Make sure the Ollama service is running:
```bash
ollama serve
```

Check that `OLLAMA_HOST` in the config file is correct.

### 2. Model unavailable

Use the `/models` endpoint to list available models and `/set_model` to switch.

### 3. Slow generation

- Lower `GENERATION_POOL_SIZE` to reduce the number of candidates (e.g. generate 6, return 3)
- Use a smaller model (e.g. `qwen3:4b`)
- Make sure Ollama uses GPU acceleration

### 4. Low-quality instructions

- Provide a more detailed task description
- Iterate several times, keeping the best candidate each round
- Try a different model

### 5. UI rendering issues

Hard-refresh the browser cache:
- **Mac**: `Cmd + Shift + R`
- **Windows/Linux**: `Ctrl + Shift + R`

---

<details>
<summary><b>Original README (click to expand)</b></summary>

- Project overview
  - OPRO Prompt Optimizer: an interactive system for prompt optimization with multi-round reject/regenerate, semantic clustering deduplication, and Top-K representative selection.
  - The FastAPI backend exposes a REST interface; a three-pane frontend handles session management and candidate selection.
- Architecture overview
  - Frontend /ui/ → POST /query first-round candidates → POST /select choose and answer → POST /reject regenerate → POST /query_from_message optimize from the latest message → POST /message chat
  - OPRO flow: instruction construction → batched Qwen generation → embedding (Xinference→Ollama fallback) → clustering deduplication → Top-K
  - Core implementation: _qwen_xinference_demo/opro/user_prompt_optimizer.py:45-54 (candidate generation), _qwen_xinference_demo/opro/xinference_client.py:7-28 (embedding fallback)
- Environment and dependencies
  - Python ≥ 3.10 (conda virtual environment recommended)
  - Required: local Ollama service and models (e.g. qwen3:8b, qwen3-embedding:4b)
  - Optional: Xinference embedding service (http://127.0.0.1:9997/models/bge-base-zh/embed)
  - Python dependencies: fastapi, uvicorn, requests, numpy, scikit-learn, pydantic
- Install and run
  - Install dependencies
    - pip install fastapi uvicorn requests numpy scikit-learn pydantic
  - Start the backend service
    - uvicorn _qwen_xinference_demo.api:app --host 0.0.0.0 --port 8010
  - Open the pages
    - Three-pane frontend UI: http://127.0.0.1:8010/ui/
    - OpenAPI docs: http://127.0.0.1:8010/docs
    - OpenAPI JSON: http://127.0.0.1:8010/openapi.json
- Configuration
  - File: config.py
  - Key items
    - APP_TITLE, APP_DESCRIPTION, APP_VERSION, APP_CONTACT (application metadata, see _qwen_xinference_demo/api.py:14-26)
    - OLLAMA_HOST, OLLAMA_GENERATE_URL, OLLAMA_TAGS_URL (Ollama endpoints)
    - DEFAULT_CHAT_MODEL, DEFAULT_EMBED_MODEL (default models, used by _qwen_xinference_demo/opro/ollama_client.py:4-7 and _qwen_xinference_demo/opro/xinference_client.py:1-6,20-21)
    - XINFERENCE_EMBED_URL (preferred embedding endpoint)
    - TOP_K, CLUSTER_DISTANCE_THRESHOLD (candidate selection parameters, referenced in _qwen_xinference_demo/opro/user_prompt_optimizer.py:19,45)
- Uniform responses and error handling
  - Success: {"success": true, "data": {...}}
  - Failure: {"success": false, "error": "...", "error_code": "..."}, with the HTTP status code preserved
  - Application-level exception: AppException(status_code, detail, error_code) _qwen_xinference_demo/api.py:23-39
  - Examples: a missing session raises SESSION_NOT_FOUND, an Ollama call failure raises OLLAMA_ERROR
- API and examples
  - Full endpoints and examples: see API.md
  - Health and version
    - GET /health returns {status, version} _qwen_xinference_demo/api.py:129-134
    - GET /version returns {version} _qwen_xinference_demo/api.py:135-138
  - Example script
    - Entry point: examples/client_demo.py
    - Flow: health check → create session → select candidate → keep refining → chat → list sessions
- Directory layout
  - /_qwen_xinference_demo/api.py: FastAPI main application and routes
  - /_qwen_xinference_demo/opro/user_prompt_optimizer.py: OPRO candidate generation and cluster-based selection
  - /_qwen_xinference_demo/opro/xinference_client.py: embedding (Xinference→Ollama fallback)
  - /_qwen_xinference_demo/opro/ollama_client.py: Ollama calls and model listing
  - /_qwen_xinference_demo/opro/session_state.py: session state management
  - /frontend/index.html: three-pane UI page
  - /API.md: API documentation
  - /examples/client_demo.py: example client script
  - /config.py: global configuration
- FAQ
  - /ui/react is not reachable: use /ui/; the React example is only a demo entry _qwen_xinference_demo/api.py:133-144
  - Model unavailable: list models via /models and apply one via /set_model; errors return MODEL_NOT_AVAILABLE
  - No relevant candidates in the second round: use POST /query_from_message to regenerate candidates from the latest message _qwen_xinference_demo/api.py:193-206
  - Want an immediate answer: use POST /answer to answer first and then provide candidates _qwen_xinference_demo/api.py:211-219
  - Host/port access differences: pass --host 0.0.0.0 --port 8010 explicitly on startup; in a local browser prefer 127.0.0.1

</details>
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -2,14 +2,32 @@ from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List, Tuple, Optional
import config

# Legacy session management (query rewriting)
from .opro.session_state import create_session, get_session, update_session_add_candidates, log_user_choice
from .opro.session_state import log_user_reject
from .opro.session_state import set_selected_prompt, log_chat_message
from .opro.session_state import set_session_model
from .opro.session_state import USER_FEEDBACK_LOG

# True OPRO session management
from .opro.session_state import (
    create_opro_session, get_opro_session, list_opro_sessions,
    create_opro_run, get_opro_run, update_opro_iteration,
    add_opro_evaluation, get_opro_trajectory, set_opro_test_cases,
    complete_opro_run, list_opro_runs
)

# Optimization functions
from .opro.user_prompt_optimizer import generate_candidates
from .opro.user_prompt_optimizer import (
    generate_system_instruction_candidates,
    evaluate_system_instruction,
    refine_instruction_candidates
)

from .opro.ollama_client import call_qwen
from .opro.ollama_client import list_models

@@ -23,8 +41,9 @@ app = FastAPI(
    openapi_tags=[
        {"name": "health", "description": "健康检查"},
        {"name": "models", "description": "模型列表与设置"},
        {"name": "sessions", "description": "会话管理"},
        {"name": "opro", "description": "提示优化候选生成与选择/拒绝"},
        {"name": "sessions", "description": "会话管理(旧版查询重写)"},
        {"name": "opro-legacy", "description": "旧版提示优化(查询重写)"},
        {"name": "opro-true", "description": "真正的OPRO(系统指令优化)"},
        {"name": "chat", "description": "会话聊天"},
        {"name": "ui", "description": "静态页面"}
    ]

@@ -89,14 +108,79 @@ class SetModelReq(BaseModel):
    session_id: str
    model_name: str

@app.post("/start", tags=["opro"])

# ============================================================================
# TRUE OPRO REQUEST MODELS
# ============================================================================

class TestCase(BaseModel):
    """A single test case for OPRO evaluation."""
    input: str
    expected_output: str


class CreateOPRORunReq(BaseModel):
    """Request to create a new OPRO optimization run."""
    task_description: str
    test_cases: Optional[List[TestCase]] = None
    model_name: Optional[str] = None
    session_id: Optional[str] = None  # Optional session to associate with


class OPROIterateReq(BaseModel):
    """Request to run one OPRO iteration."""
    run_id: str
    top_k: Optional[int] = None


class OPROEvaluateReq(BaseModel):
    """Request to evaluate a system instruction."""
    run_id: str
    instruction: str


class OPROAddTestCasesReq(BaseModel):
    """Request to add test cases to an OPRO run."""
    run_id: str
    test_cases: List[TestCase]


class OPROGenerateAndEvaluateReq(BaseModel):
    """Request to generate and auto-evaluate candidates (for chat-like UX)."""
    run_id: str
    top_k: Optional[int] = None
    pool_size: Optional[int] = None
    auto_evaluate: Optional[bool] = True  # If False, use diversity-based selection only


class OPROExecuteReq(BaseModel):
    """Request to execute a system instruction with user input."""
    instruction: str
    user_input: str
    model_name: Optional[str] = None


class OPRORefineReq(BaseModel):
    """Request to refine based on selected instruction (simple iterative refinement, NOT OPRO)."""
    run_id: str
    selected_instruction: str
    rejected_instructions: List[str]
    top_k: Optional[int] = None
    pool_size: Optional[int] = None


# ============================================================================
# LEGACY ENDPOINTS (Query Rewriting - NOT true OPRO)
# ============================================================================

@app.post("/start", tags=["opro-legacy"])
def start(req: StartReq):
    sid = create_session(req.query)
    cands = generate_candidates(req.query, [], model_name=get_session(sid).get("model_name"))
    update_session_add_candidates(sid, cands)
    return ok({"session_id": sid, "round": 0, "candidates": cands})

@app.post("/next", tags=["opro"])
@app.post("/next", tags=["opro-legacy"])
def next_round(req: NextReq):
    s = get_session(req.session_id)
    if not s:

@@ -110,7 +194,7 @@ def next_round(req: NextReq):
    update_session_add_candidates(req.session_id, cands)
    return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})

@app.post("/select", tags=["opro"])
@app.post("/select", tags=["opro-legacy"])
def select(req: SelectReq):
    s = get_session(req.session_id)
    if not s:

@@ -138,7 +222,7 @@ def select(req: SelectReq):
            pass
    return ok({"prompt": req.choice, "answer": ans})

@app.post("/reject", tags=["opro"])
@app.post("/reject", tags=["opro-legacy"])
def reject(req: RejectReq):
    s = get_session(req.session_id)
    if not s:

@@ -151,7 +235,7 @@ class QueryReq(BaseModel):
    query: str
    session_id: str | None = None

@app.post("/query", tags=["opro"])
@app.post("/query", tags=["opro-legacy"])
def query(req: QueryReq):
    if req.session_id:
        s = get_session(req.session_id)

@@ -240,7 +324,7 @@ def message(req: MessageReq):
class QueryFromMsgReq(BaseModel):
    session_id: str

@app.post("/query_from_message", tags=["opro"])
@app.post("/query_from_message", tags=["opro-legacy"])
def query_from_message(req: QueryFromMsgReq):
    s = get_session(req.session_id)
    if not s:

@@ -258,7 +342,7 @@ def query_from_message(req: QueryFromMsgReq):
class AnswerReq(BaseModel):
    query: str

@app.post("/answer", tags=["opro"])
@app.post("/answer", tags=["opro-legacy"])
def answer(req: AnswerReq):
    sid = create_session(req.query)
    log_chat_message(sid, "user", req.query)

@@ -282,3 +366,384 @@ def set_model(req: SetModelReq):
        raise AppException(400, f"model not available: {req.model_name}", "MODEL_NOT_AVAILABLE")
    set_session_model(req.session_id, req.model_name)
    return ok({"session_id": req.session_id, "model_name": req.model_name})


# ============================================================================
# TRUE OPRO ENDPOINTS (System Instruction Optimization)
# ============================================================================

# Session Management

@app.post("/opro/session/create", tags=["opro-true"])
def opro_create_session(session_name: str = None):
    """
    Create a new OPRO session that can contain multiple runs.
    """
    session_id = create_opro_session(session_name=session_name)
    session = get_opro_session(session_id)

    return ok({
        "session_id": session_id,
        "session_name": session["session_name"],
        "num_runs": len(session["run_ids"])
    })


@app.get("/opro/sessions", tags=["opro-true"])
def opro_list_sessions():
    """
    List all OPRO sessions.
    """
    sessions = list_opro_sessions()
    return ok({"sessions": sessions})


@app.get("/opro/session/{session_id}", tags=["opro-true"])
def opro_get_session(session_id: str):
    """
    Get detailed information about an OPRO session.
    """
    session = get_opro_session(session_id)
    if not session:
        raise AppException(404, "Session not found", "SESSION_NOT_FOUND")

    # Get all runs in this session
    runs = list_opro_runs(session_id=session_id)

    return ok({
        "session_id": session_id,
        "session_name": session["session_name"],
        "created_at": session["created_at"],
        "num_runs": len(session["run_ids"]),
        "runs": runs
    })


# Run Management

@app.post("/opro/create", tags=["opro-true"])
def opro_create_run(req: CreateOPRORunReq):
    """
    Create a new OPRO optimization run.

    This starts a new system instruction optimization process for a given task.
    Optionally can be associated with a session.
    """
    # Convert test cases from Pydantic models to tuples
    test_cases = None
    if req.test_cases:
        test_cases = [(tc.input, tc.expected_output) for tc in req.test_cases]

    run_id = create_opro_run(
        task_description=req.task_description,
        test_cases=test_cases,
        model_name=req.model_name,
        session_id=req.session_id
    )

    run = get_opro_run(run_id)

    return ok({
        "run_id": run_id,
        "task_description": run["task_description"],
        "num_test_cases": len(run["test_cases"]),
        "iteration": run["iteration"],
        "status": run["status"],
        "session_id": run.get("session_id")
    })


@app.post("/opro/iterate", tags=["opro-true"])
def opro_iterate(req: OPROIterateReq):
    """
    Run one OPRO iteration: generate new system instruction candidates.

    This generates optimized system instructions based on the performance trajectory.
    """
    run = get_opro_run(req.run_id)
    if not run:
        raise AppException(404, "OPRO run not found", "RUN_NOT_FOUND")

    # Get trajectory for optimization
    trajectory = get_opro_trajectory(req.run_id)

    # Generate candidates
    top_k = req.top_k or config.TOP_K
    try:
        candidates = generate_system_instruction_candidates(
            task_description=run["task_description"],
            trajectory=trajectory if trajectory else None,
            top_k=top_k,
            model_name=run["model_name"]
        )
    except Exception as e:
        raise AppException(500, f"Failed to generate candidates: {e}", "GENERATION_ERROR")

    # Update run with new candidates
    update_opro_iteration(req.run_id, candidates)

    return ok({
        "run_id": req.run_id,
        "iteration": run["iteration"] + 1,
        "candidates": candidates,
        "num_candidates": len(candidates),
        "best_score": run["best_score"]
    })


@app.post("/opro/evaluate", tags=["opro-true"])
def opro_evaluate(req: OPROEvaluateReq):
    """
    Evaluate a system instruction on the test cases.

    This scores the instruction and updates the performance trajectory.
    If no test cases are defined, uses a default score of 0.5 to indicate user selection.
    """
    run = get_opro_run(req.run_id)
    if not run:
        raise AppException(404, "OPRO run not found", "RUN_NOT_FOUND")

    # Evaluate the instruction if test cases exist
    if run["test_cases"] and len(run["test_cases"]) > 0:
        try:
            score = evaluate_system_instruction(
                system_instruction=req.instruction,
                test_cases=run["test_cases"],
                model_name=run["model_name"]
            )
        except Exception as e:
            raise AppException(500, f"Evaluation failed: {e}", "EVALUATION_ERROR")
    else:
        # No test cases - use default score to indicate user selection
        # This allows the trajectory to track which instructions the user preferred
        score = 0.5

    # Add to trajectory
    add_opro_evaluation(req.run_id, req.instruction, score)

    # Get updated run info
    run = get_opro_run(req.run_id)

    return ok({
        "run_id": req.run_id,
        "instruction": req.instruction,
        "score": score,
        "best_score": run["best_score"],
        "is_new_best": score == run["best_score"] and score > 0,
        "has_test_cases": len(run["test_cases"]) > 0
    })


@app.get("/opro/runs", tags=["opro-true"])
def opro_list_runs():
    """
    List all OPRO optimization runs.
    """
    runs = list_opro_runs()
    return ok({"runs": runs, "total": len(runs)})


@app.get("/opro/run/{run_id}", tags=["opro-true"])
def opro_get_run(run_id: str):
    """
    Get detailed information about an OPRO run.
    """
    run = get_opro_run(run_id)
    if not run:
        raise AppException(404, "OPRO run not found", "RUN_NOT_FOUND")

    # Get sorted trajectory
    trajectory = get_opro_trajectory(run_id)

    return ok({
        "run_id": run_id,
        "task_description": run["task_description"],
        "iteration": run["iteration"],
        "status": run["status"],
        "best_score": run["best_score"],
        "best_instruction": run["best_instruction"],
        "num_test_cases": len(run["test_cases"]),
        "test_cases": [{"input": tc[0], "expected_output": tc[1]} for tc in run["test_cases"]],
        "trajectory": [{"instruction": inst, "score": score} for inst, score in trajectory[:10]],  # Top 10
        "current_candidates": run["current_candidates"]
    })


@app.post("/opro/test_cases", tags=["opro-true"])
def opro_add_test_cases(req: OPROAddTestCasesReq):
    """
    Add or update test cases for an OPRO run.
    """
    run = get_opro_run(req.run_id)
    if not run:
        raise AppException(404, "OPRO run not found", "RUN_NOT_FOUND")

    # Convert test cases
    test_cases = [(tc.input, tc.expected_output) for tc in req.test_cases]

    # Update test cases
    set_opro_test_cases(req.run_id, test_cases)

    return ok({
        "run_id": req.run_id,
        "num_test_cases": len(test_cases),
        "test_cases": [{"input": tc[0], "expected_output": tc[1]} for tc in test_cases]
    })


@app.post("/opro/generate_and_evaluate", tags=["opro-true"])
def opro_generate_and_evaluate(req: OPROGenerateAndEvaluateReq):
    """
    Generate candidates and auto-evaluate them (for chat-like UX).

    This is the main endpoint for the chat interface. It:
    1. Generates candidates based on trajectory
    2. Auto-evaluates them (if test cases exist and auto_evaluate=True)
    3. Returns top-k sorted by score (or diversity if no evaluation)
    """
    run = get_opro_run(req.run_id)
    if not run:
        raise AppException(404, "OPRO run not found", "RUN_NOT_FOUND")

    top_k = req.top_k or config.TOP_K
    pool_size = req.pool_size or config.GENERATION_POOL_SIZE

    # Get trajectory for optimization
    trajectory = get_opro_trajectory(req.run_id)

    # Generate candidates
    try:
        candidates = generate_system_instruction_candidates(
            task_description=run["task_description"],
            trajectory=trajectory if trajectory else None,
            top_k=pool_size,  # Generate pool_size candidates first
            pool_size=pool_size,
            model_name=run["model_name"]
        )
    except Exception as e:
        raise AppException(500, f"Failed to generate candidates: {e}", "GENERATION_ERROR")

    # Decide whether to evaluate
    should_evaluate = req.auto_evaluate and len(run["test_cases"]) > 0

    if should_evaluate:
        # Auto-evaluate all candidates
        scored_candidates = []
        for candidate in candidates:
            try:
                score = evaluate_system_instruction(
                    system_instruction=candidate,
                    test_cases=run["test_cases"],
                    model_name=run["model_name"]
                )
                scored_candidates.append({"instruction": candidate, "score": score})

                # Add to trajectory
                add_opro_evaluation(req.run_id, candidate, score)
            except Exception as e:
                # If evaluation fails, assign score 0
                scored_candidates.append({"instruction": candidate, "score": 0.0})

        # Sort by score (highest first)
        scored_candidates.sort(key=lambda x: x["score"], reverse=True)

        # Return top-k
        top_candidates = scored_candidates[:top_k]

        # Update iteration
        update_opro_iteration(req.run_id, [c["instruction"] for c in top_candidates])

        return ok({
            "run_id": req.run_id,
            "candidates": top_candidates,
            "iteration": run["iteration"] + 1,
            "evaluated": True,
            "best_score": run["best_score"]
        })
    else:
        # No evaluation - use diversity-based selection (already done by clustering)
        # Just return the candidates without scores
        top_candidates = [
            {"instruction": candidate, "score": None}
            for candidate in candidates[:top_k]
        ]

        # Update iteration
        update_opro_iteration(req.run_id, [c["instruction"] for c in top_candidates])

        return ok({
            "run_id": req.run_id,
            "candidates": top_candidates,
            "iteration": run["iteration"] + 1,
            "evaluated": False,
            "best_score": run["best_score"]
        })


@app.post("/opro/execute", tags=["opro-true"])
def opro_execute(req: OPROExecuteReq):
    """
    Execute a system instruction with user input.

    This uses the selected instruction as a system prompt and calls the LLM.
    """
    try:
        # Construct full prompt with system instruction
        full_prompt = f"{req.instruction}\n\n{req.user_input}"

        # Call LLM
        response = call_qwen(
            full_prompt,
            temperature=0.2,
            max_tokens=1024,
            model_name=req.model_name
        )

        return ok({
            "instruction": req.instruction,
            "user_input": req.user_input,
            "response": response
        })
    except Exception as e:
        raise AppException(500, f"Execution failed: {e}", "EXECUTION_ERROR")


@app.post("/opro/refine", tags=["opro-true"])
def opro_refine(req: OPRORefineReq):
    """
    Simple iterative refinement based on user selection (NOT OPRO).

    This generates new candidates based on the selected instruction while avoiding rejected ones.
    No scoring, no trajectory - just straightforward refinement based on user preference.
    """
    run = get_opro_run(req.run_id)
    if not run:
        raise AppException(404, "OPRO run not found", "RUN_NOT_FOUND")

    top_k = req.top_k or config.TOP_K
    pool_size = req.pool_size or config.GENERATION_POOL_SIZE

    try:
        candidates = refine_instruction_candidates(
            task_description=run["task_description"],
            selected_instruction=req.selected_instruction,
            rejected_instructions=req.rejected_instructions,
            top_k=top_k,
            pool_size=pool_size,
            model_name=run["model_name"]
        )

        # Update iteration counter
        update_opro_iteration(req.run_id, candidates)

        # Get updated run info
        run = get_opro_run(req.run_id)

        return ok({
            "run_id": req.run_id,
            "iteration": run["iteration"],
            "candidates": [{"instruction": c, "score": None} for c in candidates],
            "task_description": run["task_description"]
        })
    except Exception as e:
        raise AppException(500, f"Refinement failed: {e}", "REFINEMENT_ERROR")
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,4 +1,14 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
# ============================================================================
|
||||
# OLD FUNCTIONS (Query Rewriting - NOT true OPRO, kept for compatibility)
|
||||
# ============================================================================
|
||||
|
||||
def refine_instruction(query: str) -> str:
|
||||
"""
|
||||
LEGACY: Generates query rewrites (NOT true OPRO).
|
||||
This is query expansion, not system instruction optimization.
|
||||
"""
|
||||
return f"""
|
||||
你是一个“问题澄清与重写助手”。
|
||||
请根据用户的原始问题:
|
||||
@@ -7,6 +17,9 @@ def refine_instruction(query: str) -> str:
|
||||
"""
|
||||
def refine_instruction_with_history(query: str, rejected_list: list) -> str:
    """
    LEGACY: Generates query rewrites with rejection history (NOT true OPRO).
    """
    rejected_text = "\n".join(f"- {r}" for r in rejected_list) if rejected_list else ""
    return f"""
你是一个“问题澄清与重写助手”。
@@ -18,3 +31,158 @@ def refine_instruction_with_history(query: str, rejected_list: list) -> str:

请从新的角度重新生成至少20条不同的改写问题,每条单独一行。
"""


# ============================================================================
# TRUE OPRO FUNCTIONS (System Instruction Optimization)
# ============================================================================

def generate_initial_system_instruction_candidates(task_description: str, pool_size: int = None) -> str:
    """
    TRUE OPRO: Generates initial candidate System Instructions for a new OPRO run.

    Args:
        task_description: Description of the task the LLM should perform
        pool_size: Number of candidates to generate (defaults to config.GENERATION_POOL_SIZE)

    Returns:
        Meta-prompt that instructs the optimizer LLM to generate system instruction candidates
    """
    import config
    pool_size = pool_size or config.GENERATION_POOL_SIZE

    return f"""
你是一个"系统指令生成助手"。
目标任务描述:
【{task_description}】

请根据以上任务,生成 {pool_size} 条高质量、全面的"System Instruction"候选指令。

要求:
1. 每条指令必须以角色定义开头(例如:"你是一个..."、"你是..."等)
2. 每条指令必须全面覆盖任务的所有要求和细节
3. 指令应清晰、具体、可执行,能够有效指导LLM完成任务
4. 确保指令包含必要的行为规范、输出格式、注意事项等
5. 每条指令单独成行,不包含编号或额外说明
6. 所有生成的指令必须使用简体中文

生成 {pool_size} 条指令:
"""


def generate_optimized_system_instruction(
    task_description: str,
    trajectory: List[Tuple[str, float]],
    pool_size: int = None
) -> str:
    """
    TRUE OPRO: Analyzes performance trajectory and generates optimized System Instructions.

    This is the core OPRO function that uses an LLM as an optimizer to improve
    system instructions based on historical performance scores.

    Args:
        task_description: Description of the task the LLM should perform
        trajectory: List of (instruction, score) tuples, sorted by score (highest first)
        pool_size: Number of candidates to generate (defaults to config.GENERATION_POOL_SIZE)

    Returns:
        Meta-prompt that instructs the optimizer LLM to generate better system instructions
    """
    import config
    pool_size = pool_size or config.GENERATION_POOL_SIZE

    if not trajectory:
        # If no trajectory, fall back to initial generation
        return generate_initial_system_instruction_candidates(task_description, pool_size)

    # Format the trajectory for the Optimizer LLM
    formatted_history = "\n".join(
        f"--- Instruction Score: {score:.4f}\n{instruction}"
        for instruction, score in trajectory
    )

    # Determine the current highest score to set the optimization goal
    highest_score = max(score for _, score in trajectory)

    # Construct the Meta-Prompt (The OPRO Instruction)
    return f"""
你是一个"System Prompt 优化器"。
你的任务是改进一个LLM的系统指令,以最大化其在以下任务中的性能:
【{task_description}】

---
**历史性能轨迹 (Instructions and Scores):**
{formatted_history}
---
**当前最高得分: {highest_score:.4f}**

请分析得分最高的指令的特点和得分最低指令的缺陷。
然后,生成 {pool_size} 条新的、有潜力超越 {highest_score:.4f} 分的System Instruction。

要求:
1. 每条指令必须以角色定义开头(例如:"你是一个..."、"你是..."等)
2. 每条指令必须全面覆盖任务的所有要求和细节
3. 结合高分指令的优点,避免低分指令的缺陷
4. 指令应清晰、具体、可执行,能够有效指导LLM完成任务
5. 每条指令单独成行,不包含编号或额外说明
6. 所有生成的指令必须使用简体中文

生成 {pool_size} 条优化后的指令:
"""


def refine_based_on_selection(
    task_description: str,
    selected_instruction: str,
    rejected_instructions: List[str],
    pool_size: int = None
) -> str:
    """
    Simple refinement: Generate variations based on selected instruction while avoiding rejected ones.

    This is NOT OPRO - it's straightforward iterative refinement based on user preference.
    No scoring, no trajectory, just: "I like this one, give me more like it (but not like those)."

    Args:
        task_description: Description of the task
        selected_instruction: The instruction the user selected
        rejected_instructions: The instructions the user didn't select
        pool_size: Number of new candidates to generate

    Returns:
        Prompt for generating refined candidates
    """
    import config
    pool_size = pool_size or config.GENERATION_POOL_SIZE

    rejected_text = ""
    if rejected_instructions:
        rejected_formatted = "\n".join(f"- {inst}" for inst in rejected_instructions)
        rejected_text = f"""
**用户未选择的指令(避免这些方向):**
{rejected_formatted}
"""

    return f"""
你是一个"System Prompt 改进助手"。
目标任务描述:
【{task_description}】

**用户选择的指令(基于此改进):**
{selected_instruction}
{rejected_text}

请基于用户选择的指令,生成 {pool_size} 条改进版本。

要求:
1. 每条指令必须以角色定义开头(例如:"你是一个..."、"你是..."等)
2. 保留用户选择指令的核心优点
3. 每条指令必须全面覆盖任务的所有要求和细节
4. 指令应清晰、具体、可执行,能够有效指导LLM完成任务
5. 避免与未选择指令相似的方向
6. 每条指令单独成行,不包含编号或额外说明
7. 所有生成的指令必须使用简体中文

生成 {pool_size} 条改进后的指令:
"""
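A minimal usage sketch of the trajectory-driven meta-prompt builder added above (illustrative only: the module path is inferred from the package's other imports, the trajectory scores are invented, and `config` must be importable from the project root as in examples/opro_demo.py):

# Illustrative sketch, not part of this commit.
from _qwen_xinference_demo.opro.prompt_utils import generate_optimized_system_instruction

# Toy (instruction, score) history, best first; scores are made up.
trajectory = [
    ("你是一个专业的中英翻译助手,输出自然流畅的英文。", 0.75),
    ("你是一个翻译机器人。", 0.25),
]

# Builds the OPRO meta-prompt asking the optimizer LLM for 10 candidates
# expected to beat the current best score of 0.7500.
meta_prompt = generate_optimized_system_instruction(
    task_description="将用户输入的中文句子翻译成英文",
    trajectory=trajectory,
    pool_size=10,
)
print(meta_prompt)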
@@ -1,8 +1,14 @@
import uuid
from typing import List, Tuple, Dict, Any

# Legacy session storage (for query rewriting)
SESSIONS = {}
USER_FEEDBACK_LOG = []

# OPRO session storage (for system instruction optimization)
OPRO_RUNS = {}
OPRO_RUN_LOG = []

def create_session(query: str) -> str:
    sid = uuid.uuid4().hex
    SESSIONS[sid] = {
@@ -54,3 +60,234 @@ def set_session_model(sid: str, model_name: str | None):
    s = SESSIONS.get(sid)
    if s is not None:
        s["model_name"] = model_name


# ============================================================================
# TRUE OPRO SESSION MANAGEMENT
# ============================================================================

# Session storage (contains multiple runs)
OPRO_SESSIONS = {}

def create_opro_session(session_name: str = None) -> str:
    """
    Create a new OPRO session that can contain multiple runs.

    Args:
        session_name: Optional name for the session

    Returns:
        session_id: Unique identifier for this session
    """
    session_id = uuid.uuid4().hex
    OPRO_SESSIONS[session_id] = {
        "session_name": session_name or "新会话",  # Will be updated with first task description
        "created_at": uuid.uuid1().time,
        "run_ids": [],  # List of run IDs in this session
        "chat_history": []  # Cross-run chat history
    }
    return session_id


def get_opro_session(session_id: str) -> Dict[str, Any]:
    """Get OPRO session by ID."""
    return OPRO_SESSIONS.get(session_id)


def list_opro_sessions() -> List[Dict[str, Any]]:
    """
    List all OPRO sessions with summary information.

    Returns:
        List of session summaries
    """
    return [
        {
            "session_id": session_id,
            "session_name": session["session_name"],
            "num_runs": len(session["run_ids"]),
            "created_at": session["created_at"]
        }
        for session_id, session in OPRO_SESSIONS.items()
    ]


def create_opro_run(
    task_description: str,
    test_cases: List[Tuple[str, str]] = None,
    model_name: str = None,
    session_id: str = None
) -> str:
    """
    Create a new OPRO optimization run.

    Args:
        task_description: Description of the task to optimize for
        test_cases: List of (input, expected_output) tuples for evaluation
        model_name: Optional model name to use
        session_id: Optional session ID to associate this run with

    Returns:
        run_id: Unique identifier for this OPRO run
    """
    run_id = uuid.uuid4().hex
    OPRO_RUNS[run_id] = {
        "task_description": task_description,
        "test_cases": test_cases or [],
        "model_name": model_name,
        "session_id": session_id,  # Link to parent session
        "iteration": 0,
        "trajectory": [],  # List of (instruction, score) tuples
        "best_instruction": None,
        "best_score": 0.0,
        "current_candidates": [],
        "created_at": uuid.uuid1().time,
        "status": "active"  # active, completed, failed
    }

    # Add run to session if session_id provided
    if session_id and session_id in OPRO_SESSIONS:
        OPRO_SESSIONS[session_id]["run_ids"].append(run_id)
        # Update session name with first task description if it's still default
        if OPRO_SESSIONS[session_id]["session_name"] == "新会话" and len(OPRO_SESSIONS[session_id]["run_ids"]) == 1:
            OPRO_SESSIONS[session_id]["session_name"] = task_description

    return run_id


def get_opro_run(run_id: str) -> Dict[str, Any]:
    """Get OPRO run by ID."""
    return OPRO_RUNS.get(run_id)


def update_opro_iteration(
    run_id: str,
    candidates: List[str],
    scores: List[float] = None
):
    """
    Update OPRO run with new iteration results.

    Args:
        run_id: OPRO run identifier
        candidates: List of system instruction candidates
        scores: Optional list of scores (if evaluated)
    """
    run = OPRO_RUNS.get(run_id)
    if not run:
        return

    run["iteration"] += 1
    run["current_candidates"] = candidates

    # If scores provided, update trajectory
    if scores and len(scores) == len(candidates):
        for candidate, score in zip(candidates, scores):
            run["trajectory"].append((candidate, score))

            # Update best if this is better
            if score > run["best_score"]:
                run["best_score"] = score
                run["best_instruction"] = candidate

    # Log the iteration
    OPRO_RUN_LOG.append({
        "run_id": run_id,
        "iteration": run["iteration"],
        "num_candidates": len(candidates),
        "best_score": run["best_score"]
    })


def add_opro_evaluation(
    run_id: str,
    instruction: str,
    score: float
):
    """
    Add a single evaluation result to OPRO run.

    Args:
        run_id: OPRO run identifier
        instruction: System instruction that was evaluated
        score: Performance score
    """
    run = OPRO_RUNS.get(run_id)
    if not run:
        return

    # Add to trajectory
    run["trajectory"].append((instruction, score))

    # Update best if this is better
    if score > run["best_score"]:
        run["best_score"] = score
        run["best_instruction"] = instruction


def get_opro_trajectory(run_id: str) -> List[Tuple[str, float]]:
    """
    Get the performance trajectory for an OPRO run.

    Returns:
        List of (instruction, score) tuples sorted by score (highest first)
    """
    run = OPRO_RUNS.get(run_id)
    if not run:
        return []

    trajectory = run["trajectory"]
    return sorted(trajectory, key=lambda x: x[1], reverse=True)


def set_opro_test_cases(
    run_id: str,
    test_cases: List[Tuple[str, str]]
):
    """
    Set or update test cases for an OPRO run.

    Args:
        run_id: OPRO run identifier
        test_cases: List of (input, expected_output) tuples
    """
    run = OPRO_RUNS.get(run_id)
    if run:
        run["test_cases"] = test_cases


def complete_opro_run(run_id: str):
    """Mark an OPRO run as completed."""
    run = OPRO_RUNS.get(run_id)
    if run:
        run["status"] = "completed"


def list_opro_runs(session_id: str = None) -> List[Dict[str, Any]]:
    """
    List all OPRO runs with summary information.

    Args:
        session_id: Optional session ID to filter runs by session

    Returns:
        List of run summaries
    """
    runs_to_list = OPRO_RUNS.items()

    # Filter by session if provided
    if session_id:
        runs_to_list = [(rid, r) for rid, r in runs_to_list if r.get("session_id") == session_id]

    return [
        {
            "run_id": run_id,
            "task_description": run["task_description"][:100] + "..." if len(run["task_description"]) > 100 else run["task_description"],
            "iteration": run["iteration"],
            "best_score": run["best_score"],
            "num_test_cases": len(run["test_cases"]),
            "status": run["status"],
            "session_id": run.get("session_id")
        }
        for run_id, run in runs_to_list
    ]
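A minimal sketch of the run lifecycle these storage helpers support (illustrative only; the module is imported as `store` here as an assumption, since this hunk's file name is not shown in the compare view):

# Illustrative sketch, not part of this commit.
from _qwen_xinference_demo.opro import store

session_id = store.create_opro_session()            # named "新会话" until the first run renames it
run_id = store.create_opro_run("将中文句子翻译成英文", session_id=session_id)

# Record two evaluated instructions, then read back the trajectory (best first).
store.add_opro_evaluation(run_id, "你是一个专业的中英翻译助手。", 0.8)
store.add_opro_evaluation(run_id, "你是一个翻译机器人。", 0.4)
print(store.get_opro_trajectory(run_id)[0])          # best (instruction, score) pair
store.complete_opro_run(run_id)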
@@ -1,12 +1,19 @@
import re
import numpy as np
from typing import List, Tuple
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_similarity
import config

from .ollama_client import call_qwen
from .xinference_client import embed_texts
from .prompt_utils import refine_instruction, refine_instruction_with_history
from .prompt_utils import (
    refine_instruction,
    refine_instruction_with_history,
    generate_initial_system_instruction_candidates,
    generate_optimized_system_instruction,
    refine_based_on_selection
)

def parse_candidates(raw: str) -> list:
    lines = [l.strip() for l in re.split(r'\r?\n', raw) if l.strip()]
@@ -33,7 +40,7 @@ def cluster_and_select(candidates: list, top_k=config.TOP_K, distance_threshold=
                                         linkage="average")
    labels = clustering.fit_predict(X)

    selected_idx = []
    selected_idx = []
    for label in sorted(set(labels)):
        idxs = [i for i,l in enumerate(labels) if l == label]
        sims = cosine_similarity(X[idxs]).mean(axis=1)
@@ -44,6 +51,10 @@ def cluster_and_select(candidates: list, top_k=config.TOP_K, distance_threshold=
    return selected[:top_k]

def generate_candidates(query: str, rejected=None, top_k=config.TOP_K, model_name=None):
    """
    LEGACY: Query rewriting function (NOT true OPRO).
    Kept for backward compatibility with existing API endpoints.
    """
    rejected = rejected or []
    if rejected:
        prompt = refine_instruction_with_history(query, rejected)
@@ -53,3 +64,130 @@ def generate_candidates(query: str, rejected=None, top_k=config.TOP_K, model_nam
    raw = call_qwen(prompt, temperature=0.9, max_tokens=1024, model_name=model_name)
    all_candidates = parse_candidates(raw)
    return cluster_and_select(all_candidates, top_k=top_k)


# ============================================================================
# TRUE OPRO FUNCTIONS (System Instruction Optimization)
# ============================================================================

def generate_system_instruction_candidates(
    task_description: str,
    trajectory: List[Tuple[str, float]] = None,
    top_k: int = config.TOP_K,
    pool_size: int = None,
    model_name: str = None
) -> List[str]:
    """
    TRUE OPRO: Generates optimized system instruction candidates.

    This is the core OPRO function that generates system instructions based on
    performance trajectory (if available) or initial candidates (if starting fresh).

    Args:
        task_description: Description of the task the LLM should perform
        trajectory: Optional list of (instruction, score) tuples from previous iterations
        top_k: Number of diverse candidates to return (default: config.TOP_K = 5)
        pool_size: Number of candidates to generate before clustering (default: config.GENERATION_POOL_SIZE = 10)
        model_name: Optional model name to use for generation

    Returns:
        List of top-k diverse system instruction candidates
    """
    pool_size = pool_size or config.GENERATION_POOL_SIZE

    # Generate the meta-prompt based on whether we have trajectory data
    if trajectory and len(trajectory) > 0:
        # Sort trajectory by score (highest first)
        sorted_trajectory = sorted(trajectory, key=lambda x: x[1], reverse=True)
        meta_prompt = generate_optimized_system_instruction(task_description, sorted_trajectory, pool_size)
    else:
        # No trajectory yet, generate initial candidates
        meta_prompt = generate_initial_system_instruction_candidates(task_description, pool_size)

    # Use the optimizer LLM to generate candidates
    raw = call_qwen(meta_prompt, temperature=0.9, max_tokens=1024, model_name=model_name)

    # Parse the generated candidates
    all_candidates = parse_candidates(raw)

    # Cluster and select diverse representatives
    return cluster_and_select(all_candidates, top_k=top_k)


def evaluate_system_instruction(
    system_instruction: str,
    test_cases: List[Tuple[str, str]],
    model_name: str = None
) -> float:
    """
    TRUE OPRO: Evaluates a system instruction's performance on test cases.

    Args:
        system_instruction: The system instruction to evaluate
        test_cases: List of (input, expected_output) tuples
        model_name: Optional model name to use for evaluation

    Returns:
        Performance score (0.0 to 1.0)
    """
    if not test_cases:
        return 0.0

    correct = 0
    total = len(test_cases)

    for input_text, expected_output in test_cases:
        # Construct the full prompt with system instruction
        full_prompt = f"{system_instruction}\n\n{input_text}"

        # Get LLM response
        response = call_qwen(full_prompt, temperature=0.2, max_tokens=512, model_name=model_name)

        # Simple exact match scoring (can be replaced with more sophisticated metrics)
        if expected_output.strip().lower() in response.strip().lower():
            correct += 1

    return correct / total


def refine_instruction_candidates(
    task_description: str,
    selected_instruction: str,
    rejected_instructions: List[str],
    top_k: int = config.TOP_K,
    pool_size: int = None,
    model_name: str = None
) -> List[str]:
    """
    Simple refinement: Generate new candidates based on user's selection.

    This is NOT OPRO - just straightforward iterative refinement.
    User picks a favorite, we generate variations of it while avoiding rejected ones.

    Args:
        task_description: Description of the task
        selected_instruction: The instruction the user selected
        rejected_instructions: The instructions the user didn't select
        top_k: Number of diverse candidates to return
        pool_size: Number of candidates to generate before clustering
        model_name: Optional model name to use

    Returns:
        List of refined instruction candidates
    """
    pool_size = pool_size or config.GENERATION_POOL_SIZE

    # Generate the refinement prompt
    meta_prompt = refine_based_on_selection(
        task_description,
        selected_instruction,
        rejected_instructions,
        pool_size
    )

    # Use LLM to generate refined candidates
    raw = call_qwen(meta_prompt, temperature=0.9, max_tokens=1024, model_name=model_name)

    # Parse and cluster
    all_candidates = parse_candidates(raw)
    return cluster_and_select(all_candidates, top_k=top_k)
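The evaluator above counts a test case as correct only when the expected output appears verbatim in the response, and its own comment flags that check as replaceable. A sketch of a softer token-overlap F1 score that could stand in for the substring test (illustrative only; token_f1 is a hypothetical local helper, not part of this commit):

# Illustrative alternative to the substring check in evaluate_system_instruction.
def token_f1(expected: str, response: str) -> float:
    """Whitespace-token F1 between the expected output and the model response."""
    exp = expected.strip().lower().split()
    got = response.strip().lower().split()
    if not exp or not got:
        return 0.0
    overlap = sum(min(exp.count(t), got.count(t)) for t in set(exp))
    if overlap == 0:
        return 0.0
    precision = overlap / len(got)
    recall = overlap / len(exp)
    return 2 * precision * recall / (precision + recall)

# Inside the evaluation loop, `correct += 1` would become an accumulation such as
#     score_sum += token_f1(expected_output, response)
# and the function would return score_sum / total instead of correct / total.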
|
||||
141
build-8b.sh
Executable file
@@ -0,0 +1,141 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Quick build script for qwen3:8b (lower memory usage)
|
||||
# Use this if your server has less than 12GB RAM
|
||||
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo "Building with qwen3:8b (Lower Memory)"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Memory requirements:"
|
||||
echo " - qwen3:8b: ~5GB RAM"
|
||||
echo " - qwen3:14b: ~10GB RAM"
|
||||
echo ""
|
||||
|
||||
# Check if 8b model is available
|
||||
if ! ollama list | grep -q "qwen3:8b"; then
|
||||
echo "ERROR: qwen3:8b model not found!"
|
||||
echo ""
|
||||
echo "Please download it first:"
|
||||
echo " ollama pull qwen3:8b"
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Clean up
|
||||
echo "Cleaning up previous builds..."
|
||||
rm -rf ollama-models/
|
||||
docker rmi system-prompt-optimizer:allinone 2>/dev/null || true
|
||||
|
||||
# Export 8b model
|
||||
echo ""
|
||||
echo "Exporting qwen3:8b model..."
|
||||
mkdir -p ollama-models/models/{manifests/registry.ollama.ai/library,blobs}
|
||||
|
||||
# Function to get blob hashes from manifest
|
||||
get_blobs_from_manifest() {
|
||||
local manifest_file=$1
|
||||
grep -o 'sha256:[a-f0-9]\{64\}' "$manifest_file" | sed 's/sha256://' | sort -u
|
||||
}
|
||||
|
||||
# Function to copy model files
|
||||
copy_model() {
|
||||
local model_name=$1
|
||||
local model_tag=$2
|
||||
local manifest_dir="$HOME/.ollama/models/manifests/registry.ollama.ai/library/$model_name"
|
||||
|
||||
if [ ! -d "$manifest_dir" ]; then
|
||||
echo "ERROR: Model manifest not found: $manifest_dir"
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo " Copying $model_name:$model_tag manifest..."
|
||||
mkdir -p "ollama-models/models/manifests/registry.ollama.ai/library/$model_name"
|
||||
|
||||
if [ -f "$manifest_dir/$model_tag" ]; then
|
||||
cp "$manifest_dir/$model_tag" "ollama-models/models/manifests/registry.ollama.ai/library/$model_name/"
|
||||
|
||||
echo " Finding blob files for $model_name:$model_tag..."
|
||||
local blob_hashes=$(get_blobs_from_manifest "$manifest_dir/$model_tag")
|
||||
local blob_count=0
|
||||
|
||||
for blob_hash in $blob_hashes; do
|
||||
local blob_file="$HOME/.ollama/models/blobs/sha256-$blob_hash"
|
||||
if [ -f "$blob_file" ]; then
|
||||
cp "$blob_file" "ollama-models/models/blobs/" 2>/dev/null
|
||||
blob_count=$((blob_count + 1))
|
||||
fi
|
||||
done
|
||||
|
||||
echo " ✓ $model_name:$model_tag copied ($blob_count blobs)"
|
||||
else
|
||||
echo "ERROR: Manifest file not found: $manifest_dir/$model_tag"
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Copy models
|
||||
copy_model "qwen3" "8b" || exit 1
|
||||
copy_model "qwen3-embedding" "4b" || exit 1
|
||||
|
||||
echo ""
|
||||
echo "✓ Models exported successfully"
|
||||
echo ""
|
||||
|
||||
# Update config.py to use 8b
|
||||
echo "Updating config.py to use qwen3:8b..."
|
||||
sed -i.bak 's/DEFAULT_CHAT_MODEL = "qwen3:14b"/DEFAULT_CHAT_MODEL = "qwen3:8b"/' config.py
|
||||
|
||||
# Update docker-entrypoint.sh to check for 8b
|
||||
echo "Updating docker-entrypoint.sh to check for qwen3:8b..."
|
||||
sed -i.bak 's/qwen3:14b/qwen3:8b/g' docker-entrypoint.sh
|
||||
|
||||
# Build image
|
||||
echo ""
|
||||
echo "Building Docker image..."
|
||||
docker build --platform linux/amd64 \
|
||||
-f Dockerfile.allinone \
|
||||
-t system-prompt-optimizer:allinone .
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo ""
|
||||
echo "Build failed!"
|
||||
# Restore backups
|
||||
mv config.py.bak config.py
|
||||
mv docker-entrypoint.sh.bak docker-entrypoint.sh
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Export image
|
||||
echo ""
|
||||
echo "Exporting Docker image..."
|
||||
docker save -o system-prompt-optimizer-allinone.tar system-prompt-optimizer:allinone
|
||||
|
||||
# Restore original files
|
||||
mv config.py.bak config.py
|
||||
mv docker-entrypoint.sh.bak docker-entrypoint.sh
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Build Complete!"
|
||||
echo "=========================================="
|
||||
ls -lh system-prompt-optimizer-allinone.tar
|
||||
echo ""
|
||||
echo "This image uses qwen3:8b (~5GB RAM required)"
|
||||
echo ""
|
||||
echo "Transfer to server and run:"
|
||||
echo ""
|
||||
echo " CPU mode:"
|
||||
echo " docker load -i system-prompt-optimizer-allinone.tar"
|
||||
echo " docker run -d -p 8010:8010 --restart unless-stopped system-prompt-optimizer:allinone"
|
||||
echo ""
|
||||
echo " GPU mode (recommended):"
|
||||
echo " docker load -i system-prompt-optimizer-allinone.tar"
|
||||
echo " docker run -d --gpus all -p 8010:8010 --restart unless-stopped system-prompt-optimizer:allinone"
|
||||
echo ""
|
||||
echo "Note: GPU mode provides 5-10x faster inference."
|
||||
echo " See GPU_DEPLOYMENT.md for GPU setup instructions."
|
||||
echo ""
|
||||
|
||||
133
build-allinone.sh
Executable file
@@ -0,0 +1,133 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Build all-in-one Docker image with Ollama and models
|
||||
# This creates a complete offline-deployable image
|
||||
|
||||
set -e
|
||||
|
||||
IMAGE_NAME="system-prompt-optimizer"
|
||||
IMAGE_TAG="allinone"
|
||||
EXPORT_FILE="${IMAGE_NAME}-${IMAGE_TAG}.tar"
|
||||
|
||||
echo "=========================================="
|
||||
echo "Building All-in-One Docker Image"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "This will create a Docker image containing:"
|
||||
echo " - Python application"
|
||||
echo " - Ollama service (v0.13.1)"
|
||||
echo " - qwen3:14b model"
|
||||
echo " - qwen3-embedding:4b model"
|
||||
echo ""
|
||||
echo "Target platform: linux/amd64 (x86_64)"
|
||||
echo ""
|
||||
echo "WARNING: The final image will be 10-20GB in size!"
|
||||
echo ""
|
||||
echo "NOTE: If you're building on Apple Silicon (M1/M2/M3),"
|
||||
echo " Docker will use emulation which may be slower."
|
||||
echo " The image will still work on x86_64 servers."
|
||||
echo ""
|
||||
|
||||
# Check if ollama-models directory exists
|
||||
if [ ! -d "ollama-models" ]; then
|
||||
echo "ERROR: ollama-models directory not found!"
|
||||
echo ""
|
||||
echo "Please run ./export-ollama-models.sh first to export the models."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✓ Found ollama-models directory"
|
||||
echo ""
|
||||
|
||||
# Check if Ollama binary exists
|
||||
if [ ! -f "ollama-linux-amd64.tgz" ]; then
|
||||
echo "ERROR: ollama-linux-amd64.tgz not found!"
|
||||
echo ""
|
||||
echo "Please download it first:"
|
||||
echo " curl -L -o ollama-linux-amd64.tgz https://github.com/ollama/ollama/releases/download/v0.13.1/ollama-linux-amd64.tgz"
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✓ Found ollama-linux-amd64.tgz"
|
||||
echo ""
|
||||
|
||||
# Check disk space
|
||||
AVAILABLE_SPACE=$(df -h . | awk 'NR==2 {print $4}')
|
||||
echo "Available disk space: $AVAILABLE_SPACE"
|
||||
echo "Required: ~20GB for build process"
|
||||
echo ""
|
||||
|
||||
read -p "Continue with build? (y/n) " -n 1 -r
|
||||
echo
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
echo "Build cancelled."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Building Docker image..."
|
||||
echo "=========================================="
|
||||
echo "Platform: linux/amd64 (x86_64)"
|
||||
echo "This may take 20-40 minutes depending on your machine..."
|
||||
echo ""
|
||||
|
||||
# Build for amd64 platform explicitly
|
||||
docker build --platform linux/amd64 -f Dockerfile.allinone -t ${IMAGE_NAME}:${IMAGE_TAG} .
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Build complete!"
|
||||
echo "=========================================="
|
||||
docker images | grep ${IMAGE_NAME}
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Exporting image to ${EXPORT_FILE}..."
|
||||
echo "=========================================="
|
||||
echo "This will take several minutes..."
|
||||
docker save -o ${EXPORT_FILE} ${IMAGE_NAME}:${IMAGE_TAG}
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Export complete!"
|
||||
echo "=========================================="
|
||||
ls -lh ${EXPORT_FILE}
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Deployment Instructions"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "1. Transfer ${EXPORT_FILE} to target server:"
|
||||
echo " scp ${EXPORT_FILE} user@server:/path/"
|
||||
echo ""
|
||||
echo "2. On target server, load the image:"
|
||||
echo " docker load -i ${EXPORT_FILE}"
|
||||
echo ""
|
||||
echo "3. Run the container:"
|
||||
echo ""
|
||||
echo " CPU mode:"
|
||||
echo " docker run -d \\"
|
||||
echo " --name system-prompt-optimizer \\"
|
||||
echo " -p 8010:8010 \\"
|
||||
echo " --restart unless-stopped \\"
|
||||
echo " ${IMAGE_NAME}:${IMAGE_TAG}"
|
||||
echo ""
|
||||
echo " GPU mode (recommended if NVIDIA GPU available):"
|
||||
echo " docker run -d \\"
|
||||
echo " --name system-prompt-optimizer \\"
|
||||
echo " --gpus all \\"
|
||||
echo " -p 8010:8010 \\"
|
||||
echo " --restart unless-stopped \\"
|
||||
echo " ${IMAGE_NAME}:${IMAGE_TAG}"
|
||||
echo ""
|
||||
echo " Note: Port 11434 (Ollama) is optional and only needed for debugging."
|
||||
echo " GPU mode provides 5-10x faster inference. See GPU_DEPLOYMENT.md for details."
|
||||
echo ""
|
||||
echo "4. Access the application:"
|
||||
echo " http://<server-ip>:8010/ui/opro.html"
|
||||
echo ""
|
||||
echo "See DEPLOYMENT.md for more details."
|
||||
|
||||
37
build-and-export.sh
Executable file
@@ -0,0 +1,37 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Build and export Docker image for offline deployment
|
||||
# Usage: ./build-and-export.sh
|
||||
|
||||
set -e
|
||||
|
||||
IMAGE_NAME="system-prompt-optimizer"
|
||||
IMAGE_TAG="latest"
|
||||
EXPORT_FILE="${IMAGE_NAME}.tar"
|
||||
|
||||
echo "=========================================="
|
||||
echo "Building Docker image..."
|
||||
echo "=========================================="
|
||||
docker build -t ${IMAGE_NAME}:${IMAGE_TAG} .
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Exporting Docker image to ${EXPORT_FILE}..."
|
||||
echo "=========================================="
|
||||
docker save -o ${EXPORT_FILE} ${IMAGE_NAME}:${IMAGE_TAG}
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Export complete!"
|
||||
echo "=========================================="
|
||||
ls -lh ${EXPORT_FILE}
|
||||
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo "1. Transfer ${EXPORT_FILE} to target server"
|
||||
echo "2. Transfer docker-compose.yml to target server (optional)"
|
||||
echo "3. On target server, run: docker load -i ${EXPORT_FILE}"
|
||||
echo "4. On target server, run: docker-compose up -d"
|
||||
echo ""
|
||||
echo "See DEPLOYMENT.md for detailed instructions."
|
||||
|
||||
@@ -7,13 +7,14 @@ APP_CONTACT = {"name": "OPRO Team", "url": "http://127.0.0.1:8010/ui/"}
OLLAMA_HOST = "http://127.0.0.1:11434"
OLLAMA_GENERATE_URL = f"{OLLAMA_HOST}/api/generate"
OLLAMA_TAGS_URL = f"{OLLAMA_HOST}/api/tags"
DEFAULT_CHAT_MODEL = "qwen3:8b"
DEFAULT_CHAT_MODEL = "qwen3:14b"
DEFAULT_EMBED_MODEL = "qwen3-embedding:4b"

# Xinference
XINFERENCE_EMBED_URL = "http://127.0.0.1:9997/models/bge-base-zh/embed"

# Clustering/selection
TOP_K = 5
GENERATION_POOL_SIZE = 10  # Generate this many candidates before clustering
TOP_K = 5  # Return this many diverse candidates to user
CLUSTER_DISTANCE_THRESHOLD = 0.15
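The two new knobs form a generate-then-cluster funnel: GENERATION_POOL_SIZE candidates are requested from the optimizer LLM, then clustered down to TOP_K diverse ones. A one-line sanity check of that relationship (illustrative only, values mirror the defaults above):

# Illustrative only: the pool must be at least as large as the number of picks returned.
import config

assert config.TOP_K <= config.GENERATION_POOL_SIZE  # 5 diverse picks from a pool of 10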
|
||||
|
||||
|
||||
23
docker-compose.yml
Normal file
@@ -0,0 +1,23 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
app:
|
||||
build: .
|
||||
container_name: system-prompt-optimizer
|
||||
ports:
|
||||
- "8010:8010"
|
||||
environment:
|
||||
- OLLAMA_HOST=http://host.docker.internal:11434
|
||||
- PYTHONUNBUFFERED=1
|
||||
volumes:
|
||||
- ./outputs:/app/outputs
|
||||
restart: unless-stopped
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8010/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 5s
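Once the container is up (docker-compose up -d), the healthcheck above polls the app's /health route. The same probe can be run from the host, assuming the endpoint responds as the healthcheck implies and the requests package is installed (illustrative sketch, not part of this commit):

# Illustrative host-side check mirroring the compose healthcheck.
import requests

resp = requests.get("http://127.0.0.1:8010/health", timeout=10)
resp.raise_for_status()
print("app is healthy:", resp.status_code)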
|
||||
|
||||
103
docker-entrypoint.sh
Normal file
@@ -0,0 +1,103 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo "System Prompt Optimizer - Starting Up"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Check if Ollama binary exists
|
||||
if ! command -v ollama &> /dev/null; then
|
||||
echo "ERROR: Ollama binary not found!"
|
||||
echo "Expected location: /usr/bin/ollama or /usr/local/bin/ollama"
|
||||
ls -la /usr/bin/ollama* 2>/dev/null || echo "No ollama in /usr/bin/"
|
||||
ls -la /usr/local/bin/ollama* 2>/dev/null || echo "No ollama in /usr/local/bin/"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✓ Ollama binary found: $(which ollama)"
|
||||
echo ""
|
||||
|
||||
# Check if model files exist
|
||||
echo "Checking model files..."
|
||||
if [ ! -d "/root/.ollama/models" ]; then
|
||||
echo "ERROR: /root/.ollama/models directory not found!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
MANIFEST_COUNT=$(find /root/.ollama/models/manifests -type f 2>/dev/null | wc -l)
|
||||
BLOB_COUNT=$(find /root/.ollama/models/blobs -type f 2>/dev/null | wc -l)
|
||||
|
||||
echo "✓ Found $MANIFEST_COUNT manifest files"
|
||||
echo "✓ Found $BLOB_COUNT blob files"
|
||||
|
||||
if [ "$BLOB_COUNT" -lt 10 ]; then
|
||||
echo "WARNING: Very few blob files found. Models may not be complete."
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "Starting Ollama service..."
|
||||
ollama serve > /tmp/ollama.log 2>&1 &
|
||||
OLLAMA_PID=$!
|
||||
|
||||
# Wait for Ollama to be ready
|
||||
echo "Waiting for Ollama to start..."
|
||||
OLLAMA_READY=false
|
||||
for i in {1..60}; do
|
||||
if curl -s http://localhost:11434/api/tags > /dev/null 2>&1; then
|
||||
echo "Ollama is ready!"
|
||||
OLLAMA_READY=true
|
||||
break
|
||||
fi
|
||||
echo "Waiting for Ollama... ($i/60)"
|
||||
sleep 3
|
||||
done
|
||||
|
||||
if [ "$OLLAMA_READY" = false ]; then
|
||||
echo ""
|
||||
echo "ERROR: Ollama failed to start within 3 minutes!"
|
||||
echo ""
|
||||
echo "Ollama logs:"
|
||||
cat /tmp/ollama.log
|
||||
echo ""
|
||||
echo "Check full logs with: docker logs system-prompt-optimizer"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if models exist, if not, show warning
|
||||
echo ""
|
||||
echo "Checking for models..."
|
||||
ollama list
|
||||
|
||||
echo ""
|
||||
if ! ollama list | grep -q "qwen3:14b"; then
|
||||
echo "ERROR: qwen3:14b model not found!"
|
||||
echo "The application requires qwen3:14b to function properly."
|
||||
echo ""
|
||||
echo "Available models:"
|
||||
ollama list
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! ollama list | grep -q "qwen3-embedding"; then
|
||||
echo "WARNING: qwen3-embedding model not found!"
|
||||
echo "The application requires qwen3-embedding:4b for embeddings."
|
||||
echo "Continuing anyway, but embeddings may not work."
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "✓ All required models are available"
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Starting FastAPI application..."
|
||||
echo "=========================================="
|
||||
echo "Application will be available at:"
|
||||
echo " - Web UI: http://localhost:8010/ui/opro.html"
|
||||
echo " - API Docs: http://localhost:8010/docs"
|
||||
echo " - Ollama: http://localhost:11434"
|
||||
echo ""
|
||||
|
||||
exec uvicorn _qwen_xinference_demo.api:app --host 0.0.0.0 --port 8010
|
||||
|
||||
164
examples/opro_demo.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
TRUE OPRO Demo Script
|
||||
|
||||
This script demonstrates the true OPRO (Optimization by PROmpting) functionality.
|
||||
It shows how to:
|
||||
1. Generate initial system instruction candidates
|
||||
2. Evaluate them on test cases
|
||||
3. Use the performance trajectory to generate better candidates
|
||||
"""
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, '.')
|
||||
|
||||
from _qwen_xinference_demo.opro.user_prompt_optimizer import (
|
||||
generate_system_instruction_candidates,
|
||||
evaluate_system_instruction
|
||||
)
|
||||
import config
|
||||
|
||||
|
||||
def demo_opro_workflow():
|
||||
"""
|
||||
Demonstrates a complete OPRO optimization workflow.
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("TRUE OPRO Demo - System Instruction Optimization")
|
||||
print("=" * 80)
|
||||
print(f"Pool Size: {config.GENERATION_POOL_SIZE} candidates → Clustered to Top {config.TOP_K}")
|
||||
|
||||
# Define the task
|
||||
task_description = """
|
||||
任务:将用户输入的中文句子翻译成英文。
|
||||
要求:翻译准确、自然、符合英语表达习惯。
|
||||
"""
|
||||
|
||||
print(f"\n📋 Task Description:\n{task_description}")
|
||||
|
||||
# Define test cases for evaluation
|
||||
test_cases = [
|
||||
("你好,很高兴见到你", "Hello, nice to meet you"),
|
||||
("今天天气真好", "The weather is really nice today"),
|
||||
("我喜欢学习编程", "I like learning programming"),
|
||||
("这本书很有趣", "This book is very interesting"),
|
||||
]
|
||||
|
||||
print(f"\n🧪 Test Cases: {len(test_cases)} examples")
|
||||
for i, (input_text, expected) in enumerate(test_cases, 1):
|
||||
print(f" {i}. '{input_text}' → '{expected}'")
|
||||
|
||||
# Iteration 1: Generate initial candidates
|
||||
print("\n" + "=" * 80)
|
||||
print("🔄 Iteration 1: Generating Initial System Instruction Candidates")
|
||||
print("=" * 80)
|
||||
|
||||
print("\n⏳ Generating candidates... (this may take a moment)")
|
||||
candidates_round1 = generate_system_instruction_candidates(
|
||||
task_description=task_description,
|
||||
trajectory=None, # No history yet
|
||||
top_k=3,
|
||||
model_name=None # Use default model
|
||||
)
|
||||
|
||||
print(f"\n✅ Generated {len(candidates_round1)} candidates:")
|
||||
for i, candidate in enumerate(candidates_round1, 1):
|
||||
print(f"\n Candidate {i}:")
|
||||
print(f" {candidate[:100]}..." if len(candidate) > 100 else f" {candidate}")
|
||||
|
||||
# Evaluate each candidate
|
||||
print("\n" + "-" * 80)
|
||||
print("📊 Evaluating Candidates on Test Cases")
|
||||
print("-" * 80)
|
||||
|
||||
trajectory = []
|
||||
for i, candidate in enumerate(candidates_round1, 1):
|
||||
print(f"\n⏳ Evaluating Candidate {i}...")
|
||||
score = evaluate_system_instruction(
|
||||
system_instruction=candidate,
|
||||
test_cases=test_cases,
|
||||
model_name=None
|
||||
)
|
||||
trajectory.append((candidate, score))
|
||||
print(f" Score: {score:.2%}")
|
||||
|
||||
# Sort by score
|
||||
trajectory.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
print("\n📈 Performance Summary (Round 1):")
|
||||
for i, (candidate, score) in enumerate(trajectory, 1):
|
||||
print(f" {i}. Score: {score:.2%} - {candidate[:60]}...")
|
||||
|
||||
best_score = trajectory[0][1]
|
||||
print(f"\n🏆 Best Score: {best_score:.2%}")
|
||||
|
||||
# Iteration 2: Generate optimized candidates based on trajectory
|
||||
print("\n" + "=" * 80)
|
||||
print("🔄 Iteration 2: Generating Optimized System Instructions")
|
||||
print("=" * 80)
|
||||
print(f"\n💡 Using performance trajectory to generate better candidates...")
|
||||
print(f" Goal: Beat current best score of {best_score:.2%}")
|
||||
|
||||
print("\n⏳ Generating optimized candidates...")
|
||||
candidates_round2 = generate_system_instruction_candidates(
|
||||
task_description=task_description,
|
||||
trajectory=trajectory, # Use performance history
|
||||
top_k=3,
|
||||
model_name=None
|
||||
)
|
||||
|
||||
print(f"\n✅ Generated {len(candidates_round2)} optimized candidates:")
|
||||
for i, candidate in enumerate(candidates_round2, 1):
|
||||
print(f"\n Candidate {i}:")
|
||||
print(f" {candidate[:100]}..." if len(candidate) > 100 else f" {candidate}")
|
||||
|
||||
# Evaluate new candidates
|
||||
print("\n" + "-" * 80)
|
||||
print("📊 Evaluating Optimized Candidates")
|
||||
print("-" * 80)
|
||||
|
||||
for i, candidate in enumerate(candidates_round2, 1):
|
||||
print(f"\n⏳ Evaluating Optimized Candidate {i}...")
|
||||
score = evaluate_system_instruction(
|
||||
system_instruction=candidate,
|
||||
test_cases=test_cases,
|
||||
model_name=None
|
||||
)
|
||||
trajectory.append((candidate, score))
|
||||
print(f" Score: {score:.2%}")
|
||||
if score > best_score:
|
||||
print(f" 🎉 NEW BEST! Improved from {best_score:.2%} to {score:.2%}")
|
||||
best_score = score
|
||||
|
||||
# Final summary
|
||||
trajectory.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("🏁 Final Results")
|
||||
print("=" * 80)
|
||||
print(f"\n🏆 Best System Instruction (Score: {trajectory[0][1]:.2%}):")
|
||||
print(f"\n{trajectory[0][0]}")
|
||||
|
||||
print("\n📊 All Candidates Ranked:")
|
||||
for i, (candidate, score) in enumerate(trajectory[:5], 1):
|
||||
print(f"\n {i}. Score: {score:.2%}")
|
||||
print(f" {candidate[:80]}...")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("✅ OPRO Demo Complete!")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n⚠️ NOTE: This demo requires:")
|
||||
print(" 1. Ollama running locally (http://127.0.0.1:11434)")
|
||||
print(" 2. A Qwen model available (e.g., qwen3:8b)")
|
||||
print(" 3. An embedding model (e.g., qwen3-embedding:4b)")
|
||||
print("\n Press Ctrl+C to cancel, or Enter to continue...")
|
||||
|
||||
try:
|
||||
input()
|
||||
demo_opro_workflow()
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n❌ Demo cancelled by user.")
|
||||
sys.exit(0)
|
||||
|
||||
168
export-ollama-models.sh
Executable file
@@ -0,0 +1,168 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Export Ollama models for offline deployment
|
||||
# This script copies Ollama models from your local machine
|
||||
# so they can be bundled into the Docker image
|
||||
#
|
||||
# Required models:
|
||||
# - qwen3:14b (main chat model)
|
||||
# - qwen3-embedding:4b (embedding model)
|
||||
|
||||
set -e
|
||||
|
||||
MODELS_DIR="ollama-models"
|
||||
OLLAMA_MODELS_PATH="$HOME/.ollama"
|
||||
|
||||
echo "=========================================="
|
||||
echo "Exporting Ollama models for offline deployment"
|
||||
echo "=========================================="
|
||||
|
||||
# Check if Ollama is installed
|
||||
if ! command -v ollama &> /dev/null; then
|
||||
echo "ERROR: Ollama is not installed or not in PATH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if required models are available
|
||||
echo ""
|
||||
echo "Checking for required models..."
|
||||
MISSING_MODELS=0
|
||||
|
||||
if ! ollama list | grep -q "qwen3:14b"; then
|
||||
echo "ERROR: qwen3:14b model not found!"
|
||||
echo "Please run: ollama pull qwen3:14b"
|
||||
MISSING_MODELS=1
|
||||
fi
|
||||
|
||||
if ! ollama list | grep -q "qwen3-embedding:4b"; then
|
||||
echo "ERROR: qwen3-embedding:4b model not found!"
|
||||
echo "Please run: ollama pull qwen3-embedding:4b"
|
||||
MISSING_MODELS=1
|
||||
fi
|
||||
|
||||
if [ $MISSING_MODELS -eq 1 ]; then
|
||||
echo ""
|
||||
echo "Please download the required models first:"
|
||||
echo " ollama pull qwen3:14b"
|
||||
echo " ollama pull qwen3-embedding:4b"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✓ All required models found"
|
||||
|
||||
# Check if Ollama directory exists
|
||||
if [ ! -d "$OLLAMA_MODELS_PATH" ]; then
|
||||
echo "ERROR: Ollama directory not found at $OLLAMA_MODELS_PATH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Create export directory structure
|
||||
echo ""
|
||||
echo "Creating export directory: $MODELS_DIR"
|
||||
rm -rf "$MODELS_DIR"
|
||||
mkdir -p "$MODELS_DIR/models/manifests/registry.ollama.ai/library"
|
||||
mkdir -p "$MODELS_DIR/models/blobs"
|
||||
|
||||
echo ""
|
||||
echo "Copying only required models (qwen3:14b and qwen3-embedding:4b)..."
|
||||
echo "This may take several minutes (models are large)..."
|
||||
|
||||
# Function to get blob hashes from manifest
|
||||
get_blobs_from_manifest() {
|
||||
local manifest_file=$1
|
||||
# Extract all sha256 hashes from the manifest JSON
|
||||
grep -oE 'sha256:[a-f0-9]{64}' "$manifest_file" 2>/dev/null | sed 's/sha256://' | sort -u
|
||||
}
|
||||
|
||||
# Function to copy model files
|
||||
copy_model() {
|
||||
local model_name=$1
|
||||
local model_tag=$2
|
||||
local manifest_dir="$OLLAMA_MODELS_PATH/models/manifests/registry.ollama.ai/library/$model_name"
|
||||
|
||||
if [ ! -d "$manifest_dir" ]; then
|
||||
echo "ERROR: Model manifest not found: $manifest_dir"
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo " Copying $model_name:$model_tag manifest..."
|
||||
mkdir -p "$MODELS_DIR/models/manifests/registry.ollama.ai/library/$model_name"
|
||||
|
||||
# Copy the specific tag manifest
|
||||
if [ -f "$manifest_dir/$model_tag" ]; then
|
||||
cp "$manifest_dir/$model_tag" "$MODELS_DIR/models/manifests/registry.ollama.ai/library/$model_name/"
|
||||
|
||||
# Get all blob hashes referenced in this manifest
|
||||
echo " Finding blob files for $model_name:$model_tag..."
|
||||
local blob_hashes=$(get_blobs_from_manifest "$manifest_dir/$model_tag")
|
||||
local blob_count=0
|
||||
|
||||
for blob_hash in $blob_hashes; do
|
||||
local blob_file="$OLLAMA_MODELS_PATH/models/blobs/sha256-$blob_hash"
|
||||
if [ -f "$blob_file" ]; then
|
||||
cp "$blob_file" "$MODELS_DIR/models/blobs/" 2>/dev/null
|
||||
blob_count=$((blob_count + 1))
|
||||
fi
|
||||
done
|
||||
|
||||
echo " ✓ $model_name:$model_tag copied ($blob_count blobs)"
|
||||
else
|
||||
echo "ERROR: Manifest file not found: $manifest_dir/$model_tag"
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Copy required models with specific tags
|
||||
copy_model "qwen3" "14b" || exit 1
|
||||
copy_model "qwen3-embedding" "4b" || exit 1
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Models exported successfully!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Total size:"
|
||||
du -sh "$MODELS_DIR"
|
||||
|
||||
echo ""
|
||||
echo "Models included:"
|
||||
if [ -d "$MODELS_DIR/models/manifests/registry.ollama.ai/library" ]; then
|
||||
ls -lh "$MODELS_DIR/models/manifests/registry.ollama.ai/library/"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Blob files:"
|
||||
if [ -d "$MODELS_DIR/models/blobs" ]; then
|
||||
echo " Total blobs: $(ls -1 "$MODELS_DIR/models/blobs" | wc -l)"
|
||||
du -sh "$MODELS_DIR/models/blobs"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Summary"
|
||||
echo "=========================================="
|
||||
echo "✓ Only qwen3:14b and qwen3-embedding:4b were exported"
|
||||
echo ""
|
||||
echo "Models in your Ollama that were NOT copied:"
|
||||
ollama list | grep -v "qwen3:14b" | grep -v "qwen3-embedding:4b" | tail -n +2 || echo " (none)"
|
||||
echo ""
|
||||
echo "This keeps the Docker image size minimal!"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Next steps:"
|
||||
echo "=========================================="
|
||||
echo "1. Build the all-in-one Docker image:"
|
||||
echo " ./build-allinone.sh"
|
||||
echo ""
|
||||
echo "2. Or manually:"
|
||||
echo " docker build -f Dockerfile.allinone -t system-prompt-optimizer:allinone ."
|
||||
echo ""
|
||||
echo "3. Export the image:"
|
||||
echo " docker save -o system-prompt-optimizer-allinone.tar system-prompt-optimizer:allinone"
|
||||
echo ""
|
||||
echo "4. Transfer to target server:"
|
||||
echo " scp system-prompt-optimizer-allinone.tar user@server:/path/"
|
||||
echo ""
|
||||
echo "Note: The final Docker image will be very large (10-20GB) due to the models."
|
||||
|
||||
664
frontend/opro.html
Normal file
@@ -0,0 +1,664 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate">
|
||||
<meta http-equiv="Pragma" content="no-cache">
|
||||
<meta http-equiv="Expires" content="0">
|
||||
<title>系统提示词优化</title>
|
||||
<script crossorigin src="https://unpkg.com/react@18/umd/react.production.min.js"></script>
|
||||
<script crossorigin src="https://unpkg.com/react-dom@18/umd/react-dom.production.min.js"></script>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
font-family: 'Google Sans', 'Segoe UI', Roboto, sans-serif;
|
||||
background: #f8f9fa;
|
||||
}
|
||||
.chat-container { height: 100vh; display: flex; }
|
||||
.scrollbar-hide::-webkit-scrollbar { display: none; }
|
||||
.scrollbar-hide { -ms-overflow-style: none; scrollbar-width: none; }
|
||||
.sidebar-collapsed { width: 60px; }
|
||||
.sidebar-expanded { width: 260px; }
|
||||
.instruction-card {
|
||||
transition: all 0.15s ease;
|
||||
border: 1px solid #e8eaed;
|
||||
}
|
||||
.instruction-card:hover {
|
||||
border-color: #dadce0;
|
||||
box-shadow: 0 1px 3px rgba(60,64,67,0.15);
|
||||
}
|
||||
.loading-dots::after {
|
||||
content: '...';
|
||||
animation: dots 1.5s steps(4, end) infinite;
|
||||
}
|
||||
@keyframes dots {
|
||||
0%, 20% { content: '.'; }
|
||||
40% { content: '..'; }
|
||||
60%, 100% { content: '...'; }
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
|
||||
<script>
|
||||
const { useState, useEffect, useRef } = React;
|
||||
const API_BASE = 'http://127.0.0.1:8010';
|
||||
|
||||
// Main App Component
|
||||
function App() {
|
||||
const [sidebarOpen, setSidebarOpen] = useState(false);
|
||||
const [sessions, setSessions] = useState([]);
|
||||
const [currentSessionId, setCurrentSessionId] = useState(null);
|
||||
const [currentSessionRuns, setCurrentSessionRuns] = useState([]);
|
||||
const [currentRunId, setCurrentRunId] = useState(null);
|
||||
const [messages, setMessages] = useState([]);
|
||||
const [sessionMessages, setSessionMessages] = useState({}); // Store messages per session
|
||||
const [sessionLastRunId, setSessionLastRunId] = useState({}); // Store last run ID per session
|
||||
const [inputValue, setInputValue] = useState('');
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [models, setModels] = useState([]);
|
||||
const [selectedModel, setSelectedModel] = useState('');
|
||||
const chatEndRef = useRef(null);
|
||||
|
||||
// Load sessions and models on mount
|
||||
useEffect(() => {
|
||||
loadSessions();
|
||||
loadModels();
|
||||
}, []);
|
||||
|
||||
async function loadModels() {
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/models`);
|
||||
const data = await res.json();
|
||||
if (data.success && data.data.models) {
|
||||
setModels(data.data.models);
|
||||
if (data.data.models.length > 0) {
|
||||
setSelectedModel(data.data.models[0]);
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to load models:', err);
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-scroll chat
|
||||
useEffect(() => {
|
||||
chatEndRef.current?.scrollIntoView({ behavior: 'smooth' });
|
||||
}, [messages]);
|
||||
|
||||
async function loadSessions() {
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/opro/sessions`);
|
||||
const data = await res.json();
|
||||
if (data.success) {
|
||||
setSessions(data.data.sessions || []);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to load sessions:', err);
|
||||
}
|
||||
}
|
||||
|
||||
async function loadSessionRuns(sessionId) {
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/opro/session/${sessionId}`);
|
||||
const data = await res.json();
|
||||
if (data.success) {
|
||||
setCurrentSessionRuns(data.data.runs || []);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to load session runs:', err);
|
||||
}
|
||||
}
|
||||
|
||||
async function createNewSession() {
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/opro/session/create`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
});
|
||||
const data = await res.json();
|
||||
|
||||
if (!data.success) {
|
||||
throw new Error(data.error || 'Failed to create session');
|
||||
}
|
||||
|
||||
const sessionId = data.data.session_id;
|
||||
setCurrentSessionId(sessionId);
|
||||
setCurrentSessionRuns([]);
|
||||
setCurrentRunId(null);
|
||||
setMessages([]);
|
||||
setSessionMessages(prev => ({ ...prev, [sessionId]: [] })); // Initialize empty messages for new session
|
||||
|
||||
// Reload sessions list
|
||||
await loadSessions();
|
||||
|
||||
return sessionId;
|
||||
} catch (err) {
|
||||
alert('创建会话失败: ' + err.message);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async function createNewRun(taskDescription) {
|
||||
setLoading(true);
|
||||
try {
|
||||
// Ensure we have a session
|
||||
let sessionId = currentSessionId;
|
||||
if (!sessionId) {
|
||||
sessionId = await createNewSession();
|
||||
if (!sessionId) return;
|
||||
}
|
||||
|
||||
// Create run within session
|
||||
const res = await fetch(`${API_BASE}/opro/create`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
task_description: taskDescription,
|
||||
test_cases: [],
|
||||
model_name: selectedModel || undefined,
|
||||
session_id: sessionId
|
||||
})
|
||||
});
|
||||
const data = await res.json();
|
||||
|
||||
if (!data.success) {
|
||||
throw new Error(data.error || 'Failed to create run');
|
||||
}
|
||||
|
||||
const runId = data.data.run_id;
|
||||
setCurrentRunId(runId);
|
||||
|
||||
// Save this as the last run for this session
|
||||
setSessionLastRunId(prev => ({
|
||||
...prev,
|
||||
[sessionId]: runId
|
||||
}));
|
||||
|
||||
// Add user message to existing messages (keep chat history)
|
||||
const newUserMessage = { role: 'user', content: taskDescription };
|
||||
setMessages(prev => {
|
||||
const updated = [...prev, newUserMessage];
|
||||
// Save to session messages
|
||||
setSessionMessages(prevSessions => ({
|
||||
...prevSessions,
|
||||
[sessionId]: updated
|
||||
}));
|
||||
return updated;
|
||||
});
|
||||
|
||||
// Generate and evaluate candidates
|
||||
await generateCandidates(runId);
|
||||
|
||||
// Reload sessions and session runs
|
||||
await loadSessions();
|
||||
await loadSessionRuns(sessionId);
|
||||
} catch (err) {
|
||||
alert('创建任务失败: ' + err.message);
|
||||
console.error('Error creating run:', err);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function generateCandidates(runId) {
|
||||
setLoading(true);
|
||||
try {
|
||||
console.log('Generating candidates for run:', runId);
|
||||
const res = await fetch(`${API_BASE}/opro/generate_and_evaluate`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
run_id: runId,
|
||||
top_k: 5,
|
||||
auto_evaluate: false // Use diversity-based selection
|
||||
})
|
||||
});
|
||||
const data = await res.json();
|
||||
|
||||
console.log('Generate candidates response:', data);
|
||||
|
||||
if (!data.success) {
|
||||
throw new Error(data.error || 'Failed to generate candidates');
|
||||
}
|
||||
|
||||
// Add assistant message with candidates
|
||||
const newAssistantMessage = {
|
||||
role: 'assistant',
|
||||
type: 'candidates',
|
||||
candidates: data.data.candidates,
|
||||
iteration: data.data.iteration
|
||||
};
|
||||
setMessages(prev => {
|
||||
const updated = [...prev, newAssistantMessage];
|
||||
// Save to session messages
|
||||
if (currentSessionId) {
|
||||
setSessionMessages(prevSessions => ({
|
||||
...prevSessions,
|
||||
[currentSessionId]: updated
|
||||
}));
|
||||
}
|
||||
return updated;
|
||||
});
|
||||
} catch (err) {
|
||||
alert('生成候选指令失败: ' + err.message);
|
||||
console.error('Error generating candidates:', err);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function executeInstruction(instruction, userInput) {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/opro/execute`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
instruction: instruction,
|
||||
user_input: userInput || '请执行任务',
|
||||
model_name: selectedModel || undefined
|
||||
})
|
||||
});
|
||||
const data = await res.json();
|
||||
|
||||
if (!data.success) {
|
||||
throw new Error(data.error || 'Failed to execute');
|
||||
}
|
||||
|
||||
// Add execution result
|
||||
const newExecutionMessage = {
|
||||
role: 'assistant',
|
||||
type: 'execution',
|
||||
instruction: instruction,
|
||||
response: data.data.response
|
||||
};
|
||||
setMessages(prev => {
|
||||
const updated = [...prev, newExecutionMessage];
|
||||
// Save to session messages
|
||||
if (currentSessionId) {
|
||||
setSessionMessages(prevSessions => ({
|
||||
...prevSessions,
|
||||
[currentSessionId]: updated
|
||||
}));
|
||||
}
|
||||
return updated;
|
||||
});
|
||||
} catch (err) {
|
||||
alert('执行失败: ' + err.message);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
function handleSendMessage() {
|
||||
const msg = inputValue.trim();
|
||||
if (!msg || loading) return;
|
||||
|
||||
setInputValue('');
|
||||
|
||||
// Always create a new run with the message as task description
|
||||
createNewRun(msg);
|
||||
}
|
||||
|
||||
async function handleContinueOptimize(selectedInstruction, allCandidates) {
|
||||
if (!currentRunId || loading || !selectedInstruction) return;
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
// Get rejected instructions (all except the selected one)
|
||||
const rejectedInstructions = allCandidates
|
||||
.map(c => c.instruction)
|
||||
.filter(inst => inst !== selectedInstruction);
|
||||
|
||||
// Call the refinement endpoint
|
||||
const res = await fetch(`${API_BASE}/opro/refine`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
run_id: currentRunId,
|
||||
selected_instruction: selectedInstruction,
|
||||
rejected_instructions: rejectedInstructions
|
||||
})
|
||||
});
|
||||
const data = await res.json();
|
||||
|
||||
if (!data.success) {
|
||||
throw new Error(data.error || 'Failed to refine instruction');
|
||||
}
|
||||
|
||||
// Add refined candidates to messages
|
||||
const newMessage = {
|
||||
role: 'assistant',
|
||||
type: 'candidates',
|
||||
iteration: data.data.iteration,
|
||||
candidates: data.data.candidates
|
||||
};
|
||||
|
||||
setMessages(prev => {
|
||||
const updated = [...prev, newMessage];
|
||||
// Save to session messages
|
||||
setSessionMessages(prevSessions => ({
|
||||
...prevSessions,
|
||||
[currentSessionId]: updated
|
||||
}));
|
||||
return updated;
|
||||
});
|
||||
} catch (err) {
|
||||
alert('优化失败: ' + err.message);
|
||||
console.error('Error refining instruction:', err);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
function handleExecute(instruction) {
|
||||
if (loading) return;
|
||||
executeInstruction(instruction, '');
|
||||
}
|
||||
|
||||
function handleCopyInstruction(instruction) {
|
||||
navigator.clipboard.writeText(instruction).then(() => {
|
||||
// Could add a toast notification here
|
||||
console.log('Instruction copied to clipboard');
|
||||
}).catch(err => {
|
||||
console.error('Failed to copy:', err);
|
||||
});
|
||||
}
|
||||
|
||||
function handleNewTask() {
|
||||
// Create new run within current session
|
||||
setCurrentRunId(null);
|
||||
setMessages([]);
|
||||
setInputValue('');
|
||||
}
|
||||
|
||||
async function handleNewSession() {
|
||||
// Create completely new session
|
||||
const sessionId = await createNewSession();
|
||||
if (sessionId) {
|
||||
setCurrentSessionId(sessionId);
|
||||
setCurrentSessionRuns([]);
|
||||
setCurrentRunId(null);
|
||||
setMessages([]);
|
||||
setInputValue('');
|
||||
}
|
||||
}
|
||||
|
||||
async function handleSelectSession(sessionId) {
|
||||
setCurrentSessionId(sessionId);
|
||||
// Restore the last run ID for this session
|
||||
setCurrentRunId(sessionLastRunId[sessionId] || null);
|
||||
// Load messages from session storage
|
||||
setMessages(sessionMessages[sessionId] || []);
|
||||
await loadSessionRuns(sessionId);
|
||||
}
|
||||
|
||||
async function loadRun(runId) {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/opro/run/${runId}`);
|
||||
const data = await res.json();
|
||||
|
||||
if (!data.success) {
|
||||
throw new Error(data.error || 'Failed to load run');
|
||||
}
|
||||
|
||||
const run = data.data;
|
||||
setCurrentRunId(runId);
|
||||
|
||||
// Reconstruct messages from run data
|
||||
const msgs = [
|
||||
{ role: 'user', content: run.task_description }
|
||||
];
|
||||
|
||||
if (run.current_candidates && run.current_candidates.length > 0) {
|
||||
msgs.push({
|
||||
role: 'assistant',
|
||||
type: 'candidates',
|
||||
candidates: run.current_candidates.map(c => ({ instruction: c, score: null })),
|
||||
iteration: run.iteration
|
||||
});
|
||||
}
|
||||
|
||||
setMessages(msgs);
|
||||
} catch (err) {
|
||||
alert('加载任务失败: ' + err.message);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
return React.createElement('div', { className: 'chat-container' },
|
||||
// Sidebar
|
||||
React.createElement('div', {
|
||||
className: `bg-white border-r border-gray-200 transition-all duration-300 flex flex-col ${sidebarOpen ? 'sidebar-expanded' : 'sidebar-collapsed'}`
|
||||
},
|
||||
// Header area - Collapse button only
|
||||
React.createElement('div', { className: 'p-3 border-b border-gray-200 flex items-center justify-between' },
|
||||
sidebarOpen ? React.createElement('button', {
|
||||
onClick: () => setSidebarOpen(false),
|
||||
className: 'p-2 text-gray-600 hover:bg-gray-100 rounded-lg transition-colors'
|
||||
},
|
||||
React.createElement('svg', { width: '20', height: '20', viewBox: '0 0 24 24', fill: 'none', stroke: 'currentColor', strokeWidth: '2' },
|
||||
React.createElement('path', { d: 'M15 18l-6-6 6-6' })
|
||||
)
|
||||
) : React.createElement('button', {
|
||||
onClick: () => setSidebarOpen(true),
|
||||
className: 'w-full p-2 text-gray-600 hover:bg-gray-100 rounded-lg transition-colors flex items-center justify-center'
|
||||
},
|
||||
React.createElement('svg', { width: '20', height: '20', viewBox: '0 0 24 24', fill: 'none', stroke: 'currentColor', strokeWidth: '2' },
|
||||
React.createElement('path', { d: 'M3 12h18M3 6h18M3 18h18' })
|
||||
)
|
||||
)
|
||||
),
|
||||
// Content area
|
||||
React.createElement('div', { className: 'flex-1 overflow-y-auto scrollbar-hide p-2 flex flex-col' },
|
||||
sidebarOpen ? React.createElement(React.Fragment, null,
|
||||
// New session button (expanded)
|
||||
React.createElement('button', {
|
||||
onClick: handleNewSession,
|
||||
className: 'mb-3 px-4 py-2.5 bg-white border border-gray-300 hover:bg-gray-50 rounded-lg transition-colors flex items-center justify-center gap-2 text-gray-700 font-medium'
|
||||
},
|
||||
React.createElement('span', { className: 'text-lg' }, '+'),
|
||||
React.createElement('span', null, '新建会话')
|
||||
),
|
||||
// Sessions list
|
||||
sessions.length > 0 && React.createElement('div', { className: 'text-xs text-gray-500 mb-2 px-2' }, '会话列表'),
|
||||
sessions.map(session =>
|
||||
React.createElement('div', {
|
||||
key: session.session_id,
|
||||
onClick: () => handleSelectSession(session.session_id),
|
||||
className: `p-3 mb-1 rounded-lg cursor-pointer transition-colors flex items-center gap-2 ${
|
||||
currentSessionId === session.session_id ? 'bg-gray-100' : 'hover:bg-gray-50'
|
||||
}`
|
||||
},
|
||||
React.createElement('svg', {
|
||||
width: '16',
|
||||
height: '16',
|
||||
viewBox: '0 0 24 24',
|
||||
fill: 'none',
|
||||
stroke: 'currentColor',
|
||||
strokeWidth: '2',
|
||||
className: 'flex-shrink-0 text-gray-500'
|
||||
},
|
||||
React.createElement('path', { d: 'M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z' })
|
||||
),
|
||||
React.createElement('div', { className: 'text-sm text-gray-800 truncate flex-1' },
|
||||
session.session_name
|
||||
)
|
||||
)
|
||||
)
|
||||
) : React.createElement('button', {
|
||||
onClick: handleNewSession,
|
||||
className: 'p-2 text-gray-600 hover:bg-gray-100 rounded-lg transition-colors flex items-center justify-center',
|
||||
title: '新建会话'
|
||||
},
|
||||
React.createElement('svg', { width: '24', height: '24', viewBox: '0 0 24 24', fill: 'none', stroke: 'currentColor', strokeWidth: '2' },
|
||||
React.createElement('path', { d: 'M12 5v14M5 12h14' })
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
|
||||
// Main Chat Area
|
||||
React.createElement('div', { className: 'flex-1 flex flex-col bg-white' },
|
||||
// Header
|
||||
React.createElement('div', { className: 'px-4 py-3 border-b border-gray-200 bg-white flex items-center gap-3' },
|
||||
React.createElement('h1', { className: 'text-lg font-normal text-gray-800' },
|
||||
'系统提示词优化'
|
||||
),
|
||||
currentSessionId && React.createElement('div', { className: 'text-sm text-gray-500' },
|
||||
sessions.find(s => s.session_id === currentSessionId)?.session_name || '当前会话'
|
||||
)
|
||||
),
|
||||
|
||||
// Chat Messages
|
||||
React.createElement('div', { className: 'flex-1 overflow-y-auto scrollbar-hide p-6 space-y-6 max-w-4xl mx-auto w-full' },
|
||||
messages.map((msg, idx) => {
|
||||
if (msg.role === 'user') {
|
||||
return React.createElement('div', { key: idx, className: 'flex justify-end' },
|
||||
React.createElement('div', { className: 'max-w-2xl bg-gray-100 text-gray-800 rounded-2xl px-5 py-3' },
|
||||
msg.content
|
||||
)
|
||||
);
|
||||
} else if (msg.type === 'candidates') {
|
||||
return React.createElement('div', { key: idx, className: 'flex justify-start' },
|
||||
React.createElement('div', { className: 'w-full' },
|
||||
React.createElement('div', { className: 'mb-3' },
|
||||
React.createElement('div', { className: 'text-sm text-gray-600' },
|
||||
`优化后的提示词(第 ${msg.iteration} 轮)`
|
||||
),
|
||||
),
|
||||
msg.candidates.map((cand, cidx) =>
|
||||
React.createElement('div', {
|
||||
key: cidx,
|
||||
className: 'instruction-card bg-white rounded-xl p-5 mb-3'
|
||||
},
|
||||
React.createElement('div', { className: 'flex items-start gap-3' },
|
||||
React.createElement('div', { className: 'flex-shrink-0 w-7 h-7 bg-gray-200 text-gray-700 rounded-full flex items-center justify-center text-sm font-medium' },
|
||||
cidx + 1
|
||||
),
|
||||
React.createElement('div', { className: 'flex-1' },
|
||||
React.createElement('div', { className: 'text-gray-800 mb-4 whitespace-pre-wrap leading-relaxed' },
|
||||
cand.instruction
|
||||
),
|
||||
cand.score != null && React.createElement('div', { className: 'text-xs text-gray-500 mb-3' },
|
||||
`评分: ${cand.score.toFixed(4)}`
|
||||
),
|
||||
React.createElement('div', { className: 'flex gap-2' },
|
||||
React.createElement('button', {
|
||||
onClick: () => handleContinueOptimize(cand.instruction, msg.candidates),
|
||||
disabled: loading,
|
||||
className: 'px-4 py-2 bg-white border border-gray-300 text-gray-700 rounded-lg hover:bg-gray-50 disabled:bg-gray-100 disabled:text-gray-400 disabled:cursor-not-allowed transition-colors text-sm font-medium'
|
||||
}, '继续优化'),
|
||||
React.createElement('button', {
|
||||
onClick: () => handleCopyInstruction(cand.instruction),
|
||||
className: 'px-4 py-2 bg-white border border-gray-300 text-gray-700 rounded-lg hover:bg-gray-50 transition-colors text-sm font-medium flex items-center gap-1'
|
||||
},
|
||||
React.createElement('svg', { width: '16', height: '16', viewBox: '0 0 24 24', fill: 'none', stroke: 'currentColor', strokeWidth: '2' },
|
||||
React.createElement('rect', { x: '9', y: '9', width: '13', height: '13', rx: '2', ry: '2' }),
|
||||
React.createElement('path', { d: 'M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1' })
|
||||
),
|
||||
'复制'
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
} else if (msg.type === 'execution') {
|
||||
return React.createElement('div', { key: idx, className: 'flex justify-start' },
|
||||
React.createElement('div', { className: 'max-w-2xl bg-gray-50 border border-gray-200 rounded-2xl p-5' },
|
||||
React.createElement('div', { className: 'text-xs text-gray-600 mb-2 font-medium' },
|
||||
'执行结果'
|
||||
),
|
||||
React.createElement('div', { className: 'text-gray-800 whitespace-pre-wrap leading-relaxed' },
|
||||
msg.response
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
}),
|
||||
loading && React.createElement('div', { className: 'flex justify-start' },
|
||||
React.createElement('div', { className: 'bg-gray-100 rounded-2xl px-5 py-3 text-gray-600' },
|
||||
React.createElement('span', { className: 'loading-dots' }, '思考中')
|
||||
)
|
||||
),
|
||||
React.createElement('div', { ref: chatEndRef })
|
||||
),
|
||||
|
||||
// Input Area
|
||||
React.createElement('div', { className: 'p-6 bg-white max-w-4xl mx-auto w-full' },
|
||||
React.createElement('div', { className: 'relative' },
|
||||
React.createElement('div', { className: 'bg-white border border-gray-300 rounded-3xl shadow-sm hover:shadow-md transition-shadow focus-within:shadow-md focus-within:border-gray-400' },
|
||||
// Textarea
|
||||
React.createElement('textarea', {
|
||||
value: inputValue,
|
||||
onChange: (e) => setInputValue(e.target.value),
|
||||
onKeyDown: (e) => {
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
handleSendMessage();
|
||||
}
|
||||
},
|
||||
placeholder: '输入任务描述,创建新的优化任务...',
|
||||
disabled: loading,
|
||||
rows: 3,
|
||||
className: 'w-full px-5 pt-4 pb-2 bg-transparent focus:outline-none disabled:bg-transparent text-gray-800 placeholder-gray-500 resize-none'
|
||||
}),
|
||||
// Toolbar
|
||||
React.createElement('div', { className: 'flex items-center justify-between px-4 pb-3 pt-1 border-t border-gray-100' },
|
||||
// Left side - Model selector
|
||||
React.createElement('div', { className: 'flex items-center gap-2' },
|
||||
React.createElement('label', { className: 'text-xs text-gray-600' }, '模型:'),
|
||||
React.createElement('select', {
|
||||
value: selectedModel,
|
||||
onChange: (e) => setSelectedModel(e.target.value),
|
||||
className: 'text-sm px-2 py-1 border border-gray-300 rounded-lg bg-white text-gray-700 focus:outline-none focus:border-gray-400 cursor-pointer'
|
||||
},
|
||||
models.map(model =>
|
||||
React.createElement('option', { key: model, value: model }, model)
|
||||
)
|
||||
)
|
||||
),
|
||||
// Right side - Send button
|
||||
React.createElement('button', {
|
||||
onClick: handleSendMessage,
|
||||
disabled: loading || !inputValue.trim(),
|
||||
className: 'p-2.5 bg-gray-100 text-gray-700 rounded-full hover:bg-gray-200 disabled:bg-gray-50 disabled:text-gray-300 disabled:cursor-not-allowed transition-colors flex items-center justify-center'
|
||||
},
|
||||
React.createElement('svg', {
|
||||
width: '20',
|
||||
height: '20',
|
||||
viewBox: '0 0 24 24',
|
||||
fill: 'currentColor'
|
||||
},
|
||||
React.createElement('path', { d: 'M2.01 21L23 12 2.01 3 2 10l15 2-15 2z' })
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
React.createElement('div', { className: 'text-xs text-gray-500 mt-3 px-4' },
|
||||
currentSessionId
|
||||
? '输入任务描述,AI 将为你生成优化的系统指令'
|
||||
: '点击左侧"新建会话"开始,或直接输入任务描述自动创建会话'
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Render App
|
||||
const root = ReactDOM.createRoot(document.getElementById('root'));
|
||||
root.render(React.createElement(App));
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
@@ -1,736 +0,0 @@
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
|
||||
OPRO_ROOT_PATH = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
)
|
||||
sys.path.insert(0, OPRO_ROOT_PATH)
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import google.generativeai as palm
|
||||
import numpy as np
|
||||
import openai
|
||||
from opro import prompt_utils
|
||||
from opro.optimization import opt_utils
|
||||
import pandas as pd
|
||||
FLAGS = flags.FLAGS
|
||||
ROOT_DATA_FOLDER_PATH = os.path.join(OPRO_ROOT_PATH, "data")
|
||||
|
||||
flags.DEFINE_string("local_model_path", "", "Path to local vLLM model.")
|
||||
|
||||
_OPENAI_API_KEY = flags.DEFINE_string(
|
||||
"openai_api_key", "", "The OpenAI API key."
|
||||
)
|
||||
|
||||
_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.")
|
||||
|
||||
_SCORER = flags.DEFINE_string(
|
||||
"scorer", "text-bison", "The name of the scorer LLM."
|
||||
)
|
||||
|
||||
_OPTIMIZER = flags.DEFINE_string(
|
||||
"optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM."
|
||||
)
|
||||
|
||||
_DATASET = flags.DEFINE_string(
|
||||
"dataset", "gsm8k", "The name of dataset to search for instructions on."
|
||||
)
|
||||
|
||||
_TASK = flags.DEFINE_string(
|
||||
"task",
|
||||
"train",
|
||||
"The name of task within the above dataset to search for instructions on.",
|
||||
)
|
||||
|
||||
_INSTRUCTION_POS = flags.DEFINE_string(
|
||||
"instruction_pos",
|
||||
"A_begin",
|
||||
"The position of the instruction to search for.",
|
||||
)
|
||||
|
||||
_META_PROMPT_TYPE = flags.DEFINE_string(
|
||||
"meta_prompt_type",
|
||||
"both_instructions_and_exemplars",
|
||||
"The type of meta-prompt: whether to have both previous instructions and"
|
||||
" dataset exemplars (often for fine-tuned optimizers), or to have only"
|
||||
" previous instructions (often for pre-trained optimizers).",
|
||||
)
|
||||
|
||||
|
||||
def main(_):
|
||||
local_model_path = FLAGS.local_model_path
|
||||
openai_api_key = _OPENAI_API_KEY.value
|
||||
palm_api_key = _PALM_API_KEY.value
|
||||
scorer_llm_name = _SCORER.value
|
||||
optimizer_llm_name = _OPTIMIZER.value
|
||||
dataset_name = _DATASET.value.lower()
|
||||
task_name = _TASK.value
|
||||
meta_prompt_type = _META_PROMPT_TYPE.value
|
||||
|
||||
assert dataset_name in {
|
||||
"mmlu",
|
||||
"bbh",
|
||||
"gsm8k",
|
||||
}, "The lower-case dataset name must be one of mmlu, bbh, or gsm8k."
|
||||
if dataset_name == "mmlu":
|
||||
assert task_name in {
|
||||
"STEM",
|
||||
"humanities",
|
||||
"social sciences",
|
||||
"other (business, health, misc.)",
|
||||
} # for now only support searching on one MMLU category
|
||||
elif dataset_name == "bbh":
|
||||
assert task_name in {
|
||||
"boolean_expressions",
|
||||
"causal_judgement",
|
||||
"date_understanding",
|
||||
"disambiguation_qa",
|
||||
"dyck_languages",
|
||||
"formal_fallacies",
|
||||
"geometric_shapes",
|
||||
"hyperbaton",
|
||||
"logical_deduction_five_objects",
|
||||
"logical_deduction_seven_objects",
|
||||
"logical_deduction_three_objects",
|
||||
"movie_recommendation",
|
||||
"multistep_arithmetic_two",
|
||||
"navigate",
|
||||
"object_counting",
|
||||
"penguins_in_a_table",
|
||||
"reasoning_about_colored_objects",
|
||||
"ruin_names",
|
||||
"salient_translation_error_detection",
|
||||
"snarks",
|
||||
"sports_understanding",
|
||||
"temporal_sequences",
|
||||
"tracking_shuffled_objects_five_objects",
|
||||
"tracking_shuffled_objects_seven_objects",
|
||||
"tracking_shuffled_objects_three_objects",
|
||||
"web_of_lies",
|
||||
"word_sorting",
|
||||
}
|
||||
else:
|
||||
assert dataset_name == "gsm8k"
|
||||
assert task_name in {"train", "test"}
|
||||
|
||||
assert scorer_llm_name in {
|
||||
"text-bison",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4",
|
||||
"local",
|
||||
}
|
||||
assert optimizer_llm_name in {
|
||||
"text-bison",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4",
|
||||
"local",
|
||||
}
|
||||
assert meta_prompt_type in {
|
||||
"both_instructions_and_exemplars",
|
||||
"instructions_only",
|
||||
}
|
||||
|
||||
instruction_pos = _INSTRUCTION_POS.value
|
||||
assert instruction_pos in {
|
||||
"before_Q",
|
||||
"Q_begin",
|
||||
"Q_end",
|
||||
"A_begin",
|
||||
}, (
|
||||
"The instruction position should be either before the question, or at the"
|
||||
" beginning of the question, at the end of the question, or at the"
|
||||
" beginning of the answer."
|
||||
)
|
||||
print(
|
||||
f"scorer: {scorer_llm_name}, optimizer: {optimizer_llm_name}, dataset:"
|
||||
f" {dataset_name}, task: {task_name}, instruction_pos: {instruction_pos}"
|
||||
)
|
||||
|
||||
|
||||
if scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
|
||||
assert openai_api_key, "The OpenAI API key must be provided."
|
||||
openai.api_key = openai_api_key
|
||||
elif scorer_llm_name == "text-bison":
|
||||
assert palm_api_key, "A PaLM API key is needed when prompting the text-bison model."
|
||||
palm.configure(api_key=palm_api_key)
|
||||
elif scorer_llm_name == "local":
|
||||
# Local model; no API key is required.
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unknown scorer model: {scorer_llm_name}")
|
||||
|
||||
if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
|
||||
assert openai_api_key, "The OpenAI API key must be provided."
|
||||
openai.api_key = openai_api_key
|
||||
elif optimizer_llm_name == "text-bison":
|
||||
assert palm_api_key, "A PaLM API key is needed when prompting the text-bison model."
|
||||
palm.configure(api_key=palm_api_key)
|
||||
elif optimizer_llm_name == "local":
|
||||
# Local model; no API key is required.
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unknown scorer model: {optimizer_llm_name}")
|
||||
|
||||
|
||||
if dataset_name == "mmlu":
|
||||
root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "MMLU-data")
|
||||
elif dataset_name == "bbh":
|
||||
root_data_folder_path = os.path.join(
|
||||
ROOT_DATA_FOLDER_PATH, "BIG-Bench-Hard-data/"
|
||||
)
|
||||
else:
|
||||
assert dataset_name == "gsm8k"
|
||||
root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "gsm_data")
|
||||
|
||||
# =================== create the result directory ==========================
|
||||
datetime_str = (
|
||||
str(datetime.datetime.now().replace(microsecond=0))
|
||||
.replace(" ", "-")
|
||||
.replace(":", "-")
|
||||
)
|
||||
|
||||
|
||||
save_folder = os.path.join(
|
||||
OPRO_ROOT_PATH,
|
||||
"outputs",
|
||||
"optimization-results",
|
||||
f"{dataset_name.upper()}-{task_name}-s-{scorer_llm_name}-o-{optimizer_llm_name}-{datetime_str}/",
|
||||
)
|
||||
|
||||
|
||||
result_by_instruction_folder = os.path.join(
|
||||
save_folder, "result_by_instruction"
|
||||
)
|
||||
print(f"Results will be saved to: {os.path.abspath(result_by_instruction_folder)}")
|
||||
os.makedirs(result_by_instruction_folder, exist_ok=True)
|
||||
print(f"result directory:\n{save_folder}")
|
||||
|
||||
# ====================== scorer model configs ==============================
|
||||
|
||||
|
||||
if scorer_llm_name == "text-bison":
|
||||
# when prompting text-bison with Cloud API
|
||||
scorer_finetuned_palm_temperature = 0.0
|
||||
scorer_finetuned_palm_max_decode_steps = 1024
|
||||
scorer_finetuned_palm_batch_size = 1
|
||||
scorer_finetuned_palm_num_servers = 1
|
||||
scorer_finetuned_palm_dict = dict()
|
||||
scorer_finetuned_palm_dict["temperature"] = (
|
||||
scorer_finetuned_palm_temperature
|
||||
)
|
||||
scorer_finetuned_palm_dict["num_servers"] = (
|
||||
scorer_finetuned_palm_num_servers
|
||||
)
|
||||
scorer_finetuned_palm_dict["batch_size"] = scorer_finetuned_palm_batch_size
|
||||
scorer_finetuned_palm_dict["max_decode_steps"] = (
|
||||
scorer_finetuned_palm_max_decode_steps
|
||||
)
|
||||
|
||||
call_scorer_finetuned_palm_server_func = functools.partial(
|
||||
prompt_utils.call_palm_server_from_cloud,
|
||||
model="text-bison-001",
|
||||
temperature=scorer_finetuned_palm_dict["temperature"],
|
||||
max_decode_steps=scorer_finetuned_palm_dict["max_decode_steps"],
|
||||
)
|
||||
|
||||
scorer_llm_dict = {
|
||||
"model_type": scorer_llm_name.lower(),
|
||||
}
|
||||
scorer_llm_dict.update(scorer_finetuned_palm_dict)
|
||||
call_scorer_server_func = call_scorer_finetuned_palm_server_func
|
||||
|
||||
|
||||
elif scorer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}:
|
||||
# Switched to calling the locally served vLLM version of this function.
|
||||
scorer_gpt_max_decode_steps = 1024
|
||||
# scorer_gpt_max_decode_steps = 512
|
||||
scorer_gpt_temperature = 0.0
|
||||
|
||||
scorer_llm_dict = {
|
||||
"model_type": scorer_llm_name.lower(),
|
||||
"max_decode_steps": scorer_gpt_max_decode_steps,
|
||||
"temperature": scorer_gpt_temperature,
|
||||
"num_decodes": 1,
|
||||
"batch_size": 1,
|
||||
"num_servers": 1,
|
||||
}
|
||||
|
||||
call_scorer_server_func = functools.partial(
|
||||
prompt_utils.call_openai_server_func,  # your locally implemented vLLM call function
|
||||
max_decode_steps=scorer_gpt_max_decode_steps,
|
||||
temperature=scorer_gpt_temperature,
|
||||
local_model_path=FLAGS.local_model_path,  # pass in your local model path
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported scorer_llm_name: {scorer_llm_name}")
|
||||
|
||||
|
||||
# ====================== optimizer model configs ============================
|
||||
if optimizer_llm_name.lower() == "text-bison":
|
||||
# when prompting text-bison with Cloud API
|
||||
optimizer_finetuned_palm_temperature = 1.0
|
||||
optimizer_finetuned_palm_num_decodes = 8
|
||||
optimizer_finetuned_palm_max_decode_steps = 1024
|
||||
optimizer_finetuned_palm_batch_size = 1
|
||||
optimizer_finetuned_palm_num_servers = 1
|
||||
optimizer_finetuned_palm_dict = dict()
|
||||
optimizer_finetuned_palm_dict["temperature"] = (
|
||||
optimizer_finetuned_palm_temperature
|
||||
)
|
||||
optimizer_finetuned_palm_dict["num_decodes"] = (
|
||||
optimizer_finetuned_palm_num_decodes
|
||||
)
|
||||
optimizer_finetuned_palm_dict["batch_size"] = (
|
||||
optimizer_finetuned_palm_batch_size
|
||||
)
|
||||
optimizer_finetuned_palm_dict["num_servers"] = (
|
||||
optimizer_finetuned_palm_num_servers
|
||||
)
|
||||
optimizer_finetuned_palm_dict["max_decode_steps"] = (
|
||||
optimizer_finetuned_palm_max_decode_steps
|
||||
)
|
||||
|
||||
call_optimizer_finetuned_palm_server_func = functools.partial(
|
||||
prompt_utils.call_palm_server_from_cloud,
|
||||
model="text-bison-001",
|
||||
temperature=optimizer_finetuned_palm_dict["temperature"],
|
||||
max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"],
|
||||
)
|
||||
|
||||
optimizer_llm_dict = {
|
||||
"model_type": optimizer_llm_name.lower(),
|
||||
}
|
||||
optimizer_llm_dict.update(optimizer_finetuned_palm_dict)
|
||||
call_optimizer_server_func = call_optimizer_finetuned_palm_server_func
|
||||
|
||||
elif optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}:
|
||||
# Replace the remote call with the local vLLM version.
|
||||
optimizer_gpt_max_decode_steps = 512
|
||||
|
||||
optimizer_gpt_temperature = 1.0
|
||||
|
||||
optimizer_llm_dict = {
|
||||
"max_decode_steps": optimizer_gpt_max_decode_steps,
|
||||
"temperature": optimizer_gpt_temperature,
|
||||
"batch_size": 1,
|
||||
"num_decodes": 1,
|
||||
}
|
||||
|
||||
call_optimizer_server_func = functools.partial(
|
||||
prompt_utils.call_openai_server_func,  # your locally written vLLM call interface
|
||||
max_decode_steps=optimizer_gpt_max_decode_steps,
|
||||
temperature=optimizer_gpt_temperature,
|
||||
local_model_path=FLAGS.local_model_path,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported optimizer_llm_name: {optimizer_llm_name}")
|
||||
|
||||
|
||||
|
||||
# ====================== try calling the servers ============================
|
||||
print("\n======== testing the scorer and optimizer servers ===========")
|
||||
scorer_test_output = call_scorer_server_func(
|
||||
"Does the sun rise from the north? Just answer yes or no."
|
||||
)
|
||||
print(f"number of scorer output decodes: {len(scorer_test_output)}")
|
||||
print(f"scorer test output: {scorer_test_output}")
|
||||
optimizer_test_output = call_optimizer_server_func(
|
||||
"Does the sun rise from the north? Just answer yes or no.",
|
||||
temperature=1.0,
|
||||
)
|
||||
print(f"number of optimizer output decodes: {len(optimizer_test_output)}")
|
||||
print(f"optimizer test output: {optimizer_test_output}")
|
||||
print("Finished testing the servers.")
|
||||
|
||||
# ====================== read data ============================
|
||||
print("\n================ prompt optimization settings ==============")
|
||||
# from https://github.com/hendrycks/test/blob/master/categories.py
|
||||
subcategories = {
|
||||
"abstract_algebra": ["math"],
|
||||
"anatomy": ["health"],
|
||||
"astronomy": ["physics"],
|
||||
"business_ethics": ["business"],
|
||||
"clinical_knowledge": ["health"],
|
||||
"college_biology": ["biology"],
|
||||
"college_chemistry": ["chemistry"],
|
||||
"college_computer_science": ["computer science"],
|
||||
"college_mathematics": ["math"],
|
||||
"college_medicine": ["health"],
|
||||
"college_physics": ["physics"],
|
||||
"computer_security": ["computer science"],
|
||||
"conceptual_physics": ["physics"],
|
||||
"econometrics": ["economics"],
|
||||
"electrical_engineering": ["engineering"],
|
||||
"elementary_mathematics": ["math"],
|
||||
"formal_logic": ["philosophy"],
|
||||
"global_facts": ["other"],
|
||||
"high_school_biology": ["biology"],
|
||||
"high_school_chemistry": ["chemistry"],
|
||||
"high_school_computer_science": ["computer science"],
|
||||
"high_school_european_history": ["history"],
|
||||
"high_school_geography": ["geography"],
|
||||
"high_school_government_and_politics": ["politics"],
|
||||
"high_school_macroeconomics": ["economics"],
|
||||
"high_school_mathematics": ["math"],
|
||||
"high_school_microeconomics": ["economics"],
|
||||
"high_school_physics": ["physics"],
|
||||
"high_school_psychology": ["psychology"],
|
||||
"high_school_statistics": ["math"],
|
||||
"high_school_us_history": ["history"],
|
||||
"high_school_world_history": ["history"],
|
||||
"human_aging": ["health"],
|
||||
"human_sexuality": ["culture"],
|
||||
"international_law": ["law"],
|
||||
"jurisprudence": ["law"],
|
||||
"logical_fallacies": ["philosophy"],
|
||||
"machine_learning": ["computer science"],
|
||||
"management": ["business"],
|
||||
"marketing": ["business"],
|
||||
"medical_genetics": ["health"],
|
||||
"miscellaneous": ["other"],
|
||||
"moral_disputes": ["philosophy"],
|
||||
"moral_scenarios": ["philosophy"],
|
||||
"nutrition": ["health"],
|
||||
"philosophy": ["philosophy"],
|
||||
"prehistory": ["history"],
|
||||
"professional_accounting": ["other"],
|
||||
"professional_law": ["law"],
|
||||
"professional_medicine": ["health"],
|
||||
"professional_psychology": ["psychology"],
|
||||
"public_relations": ["politics"],
|
||||
"security_studies": ["politics"],
|
||||
"sociology": ["culture"],
|
||||
"us_foreign_policy": ["politics"],
|
||||
"virology": ["health"],
|
||||
"world_religions": ["philosophy"],
|
||||
}
|
||||
|
||||
categories = {
|
||||
"STEM": [
|
||||
"physics",
|
||||
"chemistry",
|
||||
"biology",
|
||||
"computer science",
|
||||
"math",
|
||||
"engineering",
|
||||
],
|
||||
"humanities": ["history", "philosophy", "law"],
|
||||
"social sciences": [
|
||||
"politics",
|
||||
"culture",
|
||||
"economics",
|
||||
"geography",
|
||||
"psychology",
|
||||
],
|
||||
"other (business, health, misc.)": ["other", "business", "health"],
|
||||
}
|
||||
|
||||
if dataset_name == "mmlu":
|
||||
|
||||
category_names = [task_name]
|
||||
folder_name = "test" # one of {'auxiliary_train', 'dev', 'val', 'test'}
|
||||
task_names = []
|
||||
for task_csv_name in os.listdir(
|
||||
os.path.join(root_data_folder_path, folder_name)
|
||||
):
|
||||
task_names.append(task_csv_name.split(".")[0])
|
||||
|
||||
tasks_in_category = []
|
||||
for category_name in category_names:
|
||||
for task_name in task_names:
|
||||
for subname in subcategories:
|
||||
if subname in task_name:
|
||||
if subcategories[subname][0] in categories[category_name]:
|
||||
tasks_in_category.append(task_name)
|
||||
break
|
||||
|
||||
tasks_all = [(folder_name, task_name) for task_name in tasks_in_category]
|
||||
multiple_choice_tasks = set([item[1] for item in tasks_all])
|
||||
boolean_tasks = set()
|
||||
numerical_output_tasks = set()
|
||||
|
||||
|
||||
elif dataset_name == "bbh":
|
||||
tasks_all = [task_name]
|
||||
assert (
|
||||
len(tasks_all) == 1
|
||||
), "for now only support prompt optimization on one BBH task"
|
||||
|
||||
|
||||
numerical_output_tasks = {
|
||||
"object_counting",
|
||||
"multistep_arithmetic_two",
|
||||
}
|
||||
|
||||
multiple_choice_tasks = {
|
||||
"date_understanding",
|
||||
"disambiguation_qa",
|
||||
"geometric_shapes",
|
||||
"hyperbaton",
|
||||
"logical_deduction_five_objects",
|
||||
"logical_deduction_seven_objects",
|
||||
"logical_deduction_three_objects",
|
||||
"movie_recommendation",
|
||||
"penguins_in_a_table",
|
||||
"reasoning_about_colored_objects",
|
||||
"ruin_names",
|
||||
"salient_translation_error_detection",
|
||||
"snarks",
|
||||
"temporal_sequences",
|
||||
"tracking_shuffled_objects_five_objects",
|
||||
"tracking_shuffled_objects_seven_objects",
|
||||
"tracking_shuffled_objects_three_objects",
|
||||
}
|
||||
|
||||
boolean_tasks = {
|
||||
"boolean_expressions", # True or False
|
||||
"causal_judgement", # yes or no
|
||||
"formal_fallacies", # valid or invalid
|
||||
"navigate", # yes or no
|
||||
"sports_understanding", # yes or no
|
||||
"web_of_lies", # yes or no
|
||||
}
|
||||
|
||||
else:
|
||||
assert dataset_name in {"gsm8k"}
|
||||
tasks_all = [task_name]
|
||||
multiple_choice_tasks = set()
|
||||
boolean_tasks = set()
|
||||
numerical_output_tasks = set(tasks_all)
|
||||
|
||||
if dataset_name == "mmlu":
|
||||
raw_data = pd.DataFrame()
|
||||
prediction_treat_as_number = False
|
||||
prediction_treat_as_bool = False
|
||||
elif dataset_name == "bbh":
|
||||
raw_data = []
|
||||
prediction_treat_as_number = bool(
|
||||
tasks_all[0] in numerical_output_tasks
|
||||
) # for now only check the first task
|
||||
prediction_treat_as_bool = bool(
|
||||
tasks_all[0] in boolean_tasks
|
||||
) # for now only check the first task
|
||||
print(
|
||||
f"prediction_treat_as_number: {prediction_treat_as_number},"
|
||||
f" prediction_treat_as_bool: {prediction_treat_as_bool}"
|
||||
)
|
||||
else:
|
||||
assert dataset_name == "gsm8k"
|
||||
raw_data = pd.DataFrame()
|
||||
prediction_treat_as_number = True
|
||||
prediction_treat_as_bool = False
|
||||
|
||||
for t in tasks_all:
|
||||
if dataset_name == "mmlu":
|
||||
folder_name = t[0]
|
||||
task_name = t[1]
|
||||
single_task_df = pd.read_csv(
|
||||
os.path.join(root_data_folder_path, f"{folder_name}/{task_name}.csv"),
|
||||
index_col=None,
|
||||
header=None,
|
||||
)
|
||||
raw_data = pd.concat([raw_data, single_task_df])
|
||||
elif dataset_name == "bbh":
|
||||
task_name = t
|
||||
single_task_list = opt_utils.load_bbh_task_data(
|
||||
task_name, base_dir=root_data_folder_path
|
||||
)
|
||||
raw_data += single_task_list
|
||||
else:
|
||||
assert dataset_name == "gsm8k"
|
||||
task_name = t
|
||||
f_gsm = os.path.join(root_data_folder_path, f"gsm_{task_name}.tsv")
|
||||
single_task_df = pd.read_csv(f_gsm, sep="\t", header=None)
|
||||
raw_data = pd.concat([raw_data, single_task_df])
|
||||
|
||||
if dataset_name == "mmlu":
|
||||
num_examples = raw_data.shape[0]
|
||||
elif dataset_name == "bbh":
|
||||
num_examples = len(raw_data)
|
||||
else:
|
||||
assert dataset_name in {"gsm8k"}
|
||||
num_examples = raw_data.shape[0]
|
||||
print(f"number of examples in the current task: {num_examples}")
|
||||
|
||||
# ================ split data into train/val/test ==========================
|
||||
if dataset_name == "mmlu":
|
||||
train_ratio = 0.8
|
||||
|
||||
eval_ratio = 0.2
|
||||
elif dataset_name == "gsm8k":
|
||||
# train_ratio = 0.035
|
||||
train_ratio = 0.01  # originally 0.035; reduced to 0.01, about 74 examples
|
||||
eval_ratio = 0
|
||||
else:
|
||||
assert dataset_name == "bbh"
|
||||
train_ratio = 0.2
|
||||
eval_ratio = 0
|
||||
|
||||
|
||||
assert train_ratio + eval_ratio <= 1
|
||||
test_ratio = 1 - train_ratio - eval_ratio
|
||||
print(
|
||||
f"train_ratio: {train_ratio}, eval_ratio: {eval_ratio}, "
|
||||
f"test_ratio: {test_ratio}"
|
||||
)
|
||||
np.random.seed(0)
|
||||
train_index = np.sort(
|
||||
np.array(
|
||||
np.random.choice(
|
||||
num_examples, size=int(train_ratio * num_examples), replace=False
|
||||
)
|
||||
)
|
||||
)
|
||||
eval_and_test_index = np.sort(
|
||||
np.array(list(set(np.arange(num_examples)) - set(train_index)))
|
||||
)
|
||||
eval_index = np.sort(
|
||||
np.array(
|
||||
np.random.choice(
|
||||
eval_and_test_index,
|
||||
size=int(eval_ratio * num_examples),
|
||||
replace=False,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# ========== set other optimization experiment hyperparameters ==============
|
||||
if scorer_llm_name == "text-bison":
|
||||
old_instruction_score_threshold = 0.0
|
||||
# old_instruction_score_threshold = 0.15 # for GSM8K
|
||||
else:
|
||||
assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4", "local"}
|
||||
old_instruction_score_threshold = 0.3
|
||||
|
||||
if scorer_llm_name == "text-bison":
|
||||
extract_final_answer_by_prompting_again = False
|
||||
include_qa = False
|
||||
evaluate_in_parallel = False
|
||||
else:
|
||||
assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4", "local"}
|
||||
extract_final_answer_by_prompting_again = False
|
||||
include_qa = False
|
||||
evaluate_in_parallel = False
|
||||
|
||||
optimizer_llm_temperature = optimizer_llm_dict["temperature"]
|
||||
|
||||
|
||||
# num_few_shot_questions_for_instruction_refinement = 3
|
||||
num_few_shot_questions_for_instruction_refinement = 1  # reduce the number of few-shot exemplars
|
||||
|
||||
# num_generated_instructions_in_each_step = 8
|
||||
num_generated_instructions_in_each_step = 2  # generate only 2 instructions per step
|
||||
|
||||
# num_search_steps = 200
|
||||
num_search_steps = 3  # originally 200; 3 steps are enough here
|
||||
|
||||
|
||||
initial_instructions = [
|
||||
"Let's solve the problem.",
|
||||
# "",
|
||||
# "The answer is",
|
||||
]
|
||||
few_shot_qa_pairs = True
|
||||
# one of {'accumulative_most_frequent', 'current_most_frequent', 'random',
|
||||
# 'constant'}
|
||||
few_shot_selection_criteria = "random"
|
||||
# whether to evaluate generated instructions on the exemplars in meta-prompt
|
||||
evaluate_generated_ins_on_few_shot = False
|
||||
# whether to evaluate old instructions on the exemplars in the meta-prompt
|
||||
evaluate_old_ins_on_few_shot = False
|
||||
# every this number of steps, compute the accuracies of current-step
|
||||
# instructions on the validation set
|
||||
# eval_interval = 3
|
||||
eval_interval = 1  # evaluate after every step to see results promptly
|
||||
# eval_interval = 10
|
||||
max_num_instructions = (
|
||||
20 # the maximum number of instructions and scores in the meta-prompt
|
||||
)
|
||||
# The number of buckets when converting scores to integers in the meta-prompt.
|
||||
num_score_buckets = 100
|
||||
# whether to put old instructions and scores to before exemplars in
|
||||
# the meta-prompt
|
||||
meta_prompt_instructions_before_exemplars = True
|
||||
|
||||
# ===================== run prompt optimization ======================
|
||||
|
||||
assert few_shot_selection_criteria in {
|
||||
"accumulative_most_frequent",
|
||||
"current_most_frequent",
|
||||
"random",
|
||||
"constant",
|
||||
}
|
||||
evolution_kwargs = {
|
||||
"num_search_steps": num_search_steps,
|
||||
"old_instruction_score_threshold": old_instruction_score_threshold,
|
||||
"scorer_llm_dict": scorer_llm_dict,
|
||||
"optimizer_llm_dict": optimizer_llm_dict,
|
||||
"extract_final_answer_by_prompting_again": (
|
||||
extract_final_answer_by_prompting_again
|
||||
),
|
||||
"include_qa": include_qa,
|
||||
"evaluate_in_parallel": evaluate_in_parallel,
|
||||
"tasks_all": tasks_all,
|
||||
"train_ratio": train_ratio,
|
||||
"eval_ratio": eval_ratio,
|
||||
"test_ratio": test_ratio,
|
||||
"train_index": train_index,
|
||||
"eval_index": eval_index,
|
||||
"dataset_name": dataset_name,
|
||||
"task_name": task_name,
|
||||
"num_examples": num_examples,
|
||||
"root_data_folder_path": root_data_folder_path,
|
||||
"optimizer_llm_temperature": optimizer_llm_temperature,
|
||||
# "optimizer_llm_temperature_schedule": (
|
||||
# optimizer_llm_temperature_schedule
|
||||
# ),
|
||||
# "optimizer_llm_temperature_end": optimizer_llm_temperature_end,
|
||||
"initial_instructions": initial_instructions,
|
||||
"multiple_choice_tasks": multiple_choice_tasks,
|
||||
"raw_data": raw_data,
|
||||
"call_scorer_server_func": call_scorer_server_func,
|
||||
"call_optimizer_server_func": call_optimizer_server_func,
|
||||
"instruction_pos": instruction_pos,
|
||||
"prediction_treat_as_number": prediction_treat_as_number,
|
||||
"prediction_treat_as_bool": prediction_treat_as_bool,
|
||||
"result_by_instruction_folder": result_by_instruction_folder,
|
||||
"few_shot_qa_pairs": few_shot_qa_pairs,
|
||||
"num_score_buckets": num_score_buckets,
|
||||
"max_num_instructions": max_num_instructions,
|
||||
"meta_prompt_type": meta_prompt_type,
|
||||
"meta_prompt_instructions_before_exemplars": (
|
||||
meta_prompt_instructions_before_exemplars
|
||||
),
|
||||
"few_shot_selection_criteria": few_shot_selection_criteria,
|
||||
"optimizer_llm_name": optimizer_llm_name,
|
||||
"num_generated_instructions_in_each_step": (
|
||||
num_generated_instructions_in_each_step
|
||||
),
|
||||
"evaluate_generated_ins_on_few_shot": evaluate_generated_ins_on_few_shot,
|
||||
"num_few_shot_questions_for_instruction_refinement": (
|
||||
num_few_shot_questions_for_instruction_refinement
|
||||
),
|
||||
"evaluate_old_ins_on_few_shot": evaluate_old_ins_on_few_shot,
|
||||
"eval_interval": eval_interval,
|
||||
"save_folder": save_folder,
|
||||
}
|
||||
print("=== 开始优化过程 ===")
|
||||
try:
|
||||
opt_utils.run_evolution(**evolution_kwargs)
|
||||
print("=== 优化完成 ===")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"!!! 优化失败: {e} !!!", file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
||||
@@ -1,424 +0,0 @@
|
||||
# Copyright 2023 The OPRO Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
r"""Optimize over the objective function of a linear regression problem.
|
||||
|
||||
Usage:
|
||||
|
||||
```
|
||||
python optimize_linear_regression.py --optimizer="text-bison"
|
||||
```
|
||||
|
||||
Note:
|
||||
- When using a Google-Cloud-served model (like text-bison at
|
||||
https://developers.generativeai.google/tutorials/text_quickstart), add
|
||||
`--palm_api_key="<your_key>"`
|
||||
- When using an OpenAI model, add `--openai_api_key="<your_key>"`
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
OPRO_ROOT_PATH = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
)
|
||||
sys.path.insert(0, OPRO_ROOT_PATH)
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import google.generativeai as palm
|
||||
import numpy as np
|
||||
import openai
|
||||
|
||||
from opro import prompt_utils
|
||||
|
||||
_OPENAI_API_KEY = flags.DEFINE_string(
|
||||
"openai_api_key", "", "The OpenAI API key."
|
||||
)
|
||||
|
||||
_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.")
|
||||
|
||||
_OPTIMIZER = flags.DEFINE_string(
|
||||
"optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM."
|
||||
)
|
||||
|
||||
|
||||
def main(_):
|
||||
# ============== set optimization experiment configurations ================
|
||||
num_points = 50 # number of points in linear regression
|
||||
w_true = 15 # the true w
|
||||
b_true = 14 # the true b
|
||||
max_num_steps = 500 # the number of optimization steps
|
||||
num_reps = 5 # the number of repeated runs
|
||||
max_num_pairs = 20 # the maximum number of input-output pairs in meta-prompt
|
||||
num_input_decimals = 0 # num of decimals for input values in meta-prompt
|
||||
num_output_decimals = 0 # num of decimals for output values in meta-prompt
|
||||
num_generated_points_in_each_step = 8
|
||||
|
||||
# ================ load LLM settings ===================
|
||||
optimizer_llm_name = _OPTIMIZER.value
|
||||
assert optimizer_llm_name in {
|
||||
"text-bison",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4",
|
||||
}
|
||||
openai_api_key = _OPENAI_API_KEY.value
|
||||
palm_api_key = _PALM_API_KEY.value
|
||||
|
||||
if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
|
||||
assert openai_api_key, "The OpenAI API key must be provided."
|
||||
openai.api_key = openai_api_key
|
||||
else:
|
||||
assert optimizer_llm_name == "text-bison"
|
||||
assert (
|
||||
palm_api_key
|
||||
), "A PaLM API key is needed when prompting the text-bison model."
|
||||
palm.configure(api_key=palm_api_key)
|
||||
|
||||
# =================== create the result directory ==========================
|
||||
datetime_str = (
|
||||
str(datetime.datetime.now().replace(microsecond=0))
|
||||
.replace(" ", "-")
|
||||
.replace(":", "-")
|
||||
)
|
||||
|
||||
save_folder = os.path.join(
|
||||
OPRO_ROOT_PATH,
|
||||
"outputs",
|
||||
"optimization-results",
|
||||
f"linear_regression-o-{optimizer_llm_name}-{datetime_str}/",
|
||||
)
|
||||
os.makedirs(save_folder)
|
||||
print(f"result directory:\n{save_folder}")
|
||||
|
||||
# ====================== optimizer model configs ============================
|
||||
if optimizer_llm_name.lower() == "text-bison":
|
||||
# when prompting text-bison with Cloud API
|
||||
optimizer_finetuned_palm_temperature = 1.0
|
||||
optimizer_finetuned_palm_max_decode_steps = 1024
|
||||
optimizer_finetuned_palm_batch_size = 1
|
||||
optimizer_finetuned_palm_num_servers = 1
|
||||
optimizer_finetuned_palm_dict = dict()
|
||||
optimizer_finetuned_palm_dict["temperature"] = (
|
||||
optimizer_finetuned_palm_temperature
|
||||
)
|
||||
optimizer_finetuned_palm_dict["batch_size"] = (
|
||||
optimizer_finetuned_palm_batch_size
|
||||
)
|
||||
optimizer_finetuned_palm_dict["num_servers"] = (
|
||||
optimizer_finetuned_palm_num_servers
|
||||
)
|
||||
optimizer_finetuned_palm_dict["max_decode_steps"] = (
|
||||
optimizer_finetuned_palm_max_decode_steps
|
||||
)
|
||||
|
||||
call_optimizer_finetuned_palm_server_func = functools.partial(
|
||||
prompt_utils.call_palm_server_from_cloud,
|
||||
# prompt_utils.call_vllm,
|
||||
model="text-bison-001",
|
||||
temperature=optimizer_finetuned_palm_dict["temperature"],
|
||||
max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"],
|
||||
)
|
||||
|
||||
optimizer_llm_dict = {
|
||||
"model_type": optimizer_llm_name.lower(),
|
||||
}
|
||||
optimizer_llm_dict.update(optimizer_finetuned_palm_dict)
|
||||
call_optimizer_server_func = call_optimizer_finetuned_palm_server_func
|
||||
|
||||
else:
|
||||
assert optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}
|
||||
optimizer_gpt_max_decode_steps = 1024
|
||||
optimizer_gpt_temperature = 1.0
|
||||
|
||||
optimizer_llm_dict = dict()
|
||||
optimizer_llm_dict["max_decode_steps"] = optimizer_gpt_max_decode_steps
|
||||
optimizer_llm_dict["temperature"] = optimizer_gpt_temperature
|
||||
optimizer_llm_dict["batch_size"] = 1
|
||||
call_optimizer_server_func = functools.partial(
|
||||
prompt_utils.call_openai_server_func,
|
||||
model=optimizer_llm_name,
|
||||
max_decode_steps=optimizer_gpt_max_decode_steps,
|
||||
temperature=optimizer_gpt_temperature,
|
||||
)
|
||||
|
||||
# ====================== try calling the servers ============================
|
||||
print("\n======== testing the optimizer server ===========")
|
||||
optimizer_test_output = call_optimizer_server_func(
|
||||
"Does the sun rise from the north? Just answer yes or no.",
|
||||
temperature=1.0,
|
||||
)
|
||||
print(f"optimizer test output: {optimizer_test_output}")
|
||||
print("Finished testing the optimizer server.")
|
||||
print("\n=================================================")
|
||||
|
||||
# ====================== utility functions ============================
|
||||
def evaluate_loss(X, y, w, b): # pylint: disable=invalid-name
|
||||
residual = y - (X * w + b)
|
||||
return np.linalg.norm(residual) ** 2
|
||||
|
||||
def gen_meta_prompt(
|
||||
old_value_pairs_set,
|
||||
X, # pylint: disable=invalid-name, unused-argument
|
||||
y, # pylint: disable=unused-argument
|
||||
num_input_decimals=5,
|
||||
num_output_decimals=5,
|
||||
max_num_pairs=100,
|
||||
):
|
||||
"""Generate the meta-prompt for optimization.
|
||||
|
||||
Args:
|
||||
old_value_pairs_set (set): the set of old (w, b, z) pairs.
|
||||
X (np.array): the 1D array of x values.
|
||||
y (np.array): the 1D array of y values.
|
||||
num_input_decimals (int): the number of decimals for (w, b) in the
|
||||
meta-prompt.
|
||||
num_output_decimals (int): the number of decimals for z in the meta-prompt.
|
||||
max_num_pairs (int): the maximum number of exemplars in the meta-prompt.
|
||||
|
||||
Returns:
|
||||
meta_prompt (str): the generated meta-prompt.
|
||||
"""
|
||||
old_value_pairs_set = set(
|
||||
[ # pylint: disable=g-complex-comprehension
|
||||
(
|
||||
np.round(w, num_input_decimals)
|
||||
if num_input_decimals > 0
|
||||
else int(w),
|
||||
np.round(b, num_input_decimals)
|
||||
if num_input_decimals > 0
|
||||
else int(b),
|
||||
np.round(z, num_output_decimals)
|
||||
if num_output_decimals > 0
|
||||
else int(z),
|
||||
)
|
||||
for w, b, z in old_value_pairs_set
|
||||
]
|
||||
)
|
||||
old_value_pairs = list(old_value_pairs_set)
|
||||
old_value_pairs = sorted(old_value_pairs, key=lambda x: -x[2])[
|
||||
-max_num_pairs:
|
||||
]
|
||||
old_value_pairs_substr = ""
|
||||
for w, b, z in old_value_pairs:
|
||||
old_value_pairs_substr += f"\ninput:\nw={w}, b={b}\nvalue:\n{z}\n"
|
||||
meta_prompt = """
|
||||
Now you will help me minimize a function with two input variables w, b. I have some (w, b) pairs and the function values at those points. The pairs are arranged in descending order based on their function values, where lower values are better.
|
||||
""".strip()
|
||||
meta_prompt += "\n\n"
|
||||
meta_prompt += old_value_pairs_substr.strip()
|
||||
meta_prompt += "\n\n"
|
||||
# function_analytic_form = ""
|
||||
# for xi, yi in zip(X, y):
|
||||
# function_analytic_form += f"({yi:.4f} - ({xi:.4f} * w + b)) ** 2 + "
|
||||
# function_analytic_form = function_analytic_form[:-3]
|
||||
# meta_prompt += (
|
||||
# "The function has the analytic form f(w, b) ="
|
||||
# f" {function_analytic_form}. When evaluating the value of a (w, b)"
|
||||
# " pair, you should replace the w and b in the analytic form with your"
|
||||
# " values and do the computation."
|
||||
# )
|
||||
# meta_prompt += "\n\n"
|
||||
meta_prompt += """Give me a new (w, b) pair that is different from all pairs above, and has a function value lower than any of the above. Do not write code. The output must end with a pair [w, b], where w and b are numerical values.
|
||||
""".strip()
|
||||
return meta_prompt
|
||||
|
||||
def extract_string_in_square_brackets(input_string):
|
||||
raw_result = re.findall(r"\[.*?\]", input_string)
|
||||
if raw_result:
|
||||
for pair in raw_result[::-1]:
|
||||
if "=" not in pair and ("w" in pair or "b" in pair):
|
||||
continue
|
||||
return pair[1:-1]
|
||||
return ""
|
||||
else:
|
||||
return ""
|
||||
|
||||
def parse_output(extracted_output):
|
||||
"""Parse the extracted output 'w, b' string to np.array([w, b]).
|
||||
|
||||
Args:
|
||||
extracted_output (str): the extracted output string, like '1.5, 2.5'.
|
||||
|
||||
Returns:
|
||||
parsed_output (np.array): the parsed output in a numpy array, like [1.5,
|
||||
2.5].
|
||||
"""
|
||||
if not extracted_output:
|
||||
return
|
||||
extracted_values = []
|
||||
for item in extracted_output.split(","):
|
||||
if "=" in item:
|
||||
item = item[item.index("=") + 1 :]
|
||||
extracted_values.append(item.strip())
|
||||
parsed_output = np.array(extracted_values).astype(float)
|
||||
return parsed_output
|
||||
|
||||
configs_dict = dict()
|
||||
results_dict = dict()
|
||||
num_convergence_steps = []
|
||||
for i_rep in range(num_reps):
|
||||
found_optimal = False
|
||||
print(f"\nRep {i_rep}:")
|
||||
|
||||
# ================= generate the ground truth X, y =====================
|
||||
X = np.arange(num_points).astype(float) + 1 # pylint: disable=invalid-name
|
||||
np.random.seed(i_rep + 1)
|
||||
y = X * w_true + b_true + np.random.randn(num_points)
|
||||
loss_at_true_values = evaluate_loss(X, y, w_true, b_true)
|
||||
print(f"value at (w_true, b_true): {loss_at_true_values}")
|
||||
|
||||
# ================= generate the starting points =====================
|
||||
num_starting_points = 5 # the number of initial points for optimization
|
||||
np.random.seed((i_rep + 1) * 10)
|
||||
init_w = np.random.uniform(low=10, high=20, size=num_starting_points)
|
||||
np.random.seed((i_rep + 1) * 100)
|
||||
init_b = np.random.uniform(low=10, high=20, size=num_starting_points)
|
||||
|
||||
# ====================== run optimization ============================
|
||||
configs_dict_single_rep = {
|
||||
"optimizer_llm_configs": optimizer_llm_dict,
|
||||
"data": {
|
||||
"num_points": num_points,
|
||||
"w_true": w_true,
|
||||
"b_true": b_true,
|
||||
"loss_at_true_values": loss_at_true_values,
|
||||
"X": list(X),
|
||||
"y": list(y),
|
||||
},
|
||||
"init_w": list(init_w),
|
||||
"init_b": list(init_b),
|
||||
"max_num_steps": max_num_steps,
|
||||
"max_num_pairs": max_num_pairs,
|
||||
"num_input_decimals": num_input_decimals,
|
||||
"num_output_decimals": num_output_decimals,
|
||||
"num_generated_points_in_each_step": num_generated_points_in_each_step,
|
||||
}
|
||||
configs_dict[i_rep] = configs_dict_single_rep
|
||||
configs_json_path = os.path.join(save_folder, "configs.json")
|
||||
print(f"saving configs to\n{configs_json_path}")
|
||||
with open(configs_json_path, "w") as f:
|
||||
json.dump(configs_dict, f, indent=4)
|
||||
|
||||
old_value_pairs_set = set()
|
||||
old_value_pairs_with_i_step = [] # format: [(w, b, z = f(w, b), i_step)]
|
||||
meta_prompts_dict = dict() # format: {i_step: meta_prompt}
|
||||
raw_outputs_dict = dict() # format: {i_step: raw_outputs}
|
||||
|
||||
rounded_inits = [
|
||||
(np.round(w, num_input_decimals), np.round(b, num_input_decimals))
|
||||
for w, b in zip(init_w, init_b)
|
||||
]
|
||||
rounded_inits = [
|
||||
tuple(item) for item in list(np.unique(rounded_inits, axis=0))
|
||||
]
|
||||
for w, b in rounded_inits:
|
||||
z = evaluate_loss(X, y, w, b)
|
||||
old_value_pairs_set.add((w, b, z))
|
||||
old_value_pairs_with_i_step.append((w, b, z, -1))
|
||||
|
||||
print("\n================ run optimization ==============")
|
||||
print(
|
||||
f"initial points: {[tuple(item[:2]) for item in old_value_pairs_set]}"
|
||||
)
|
||||
print(f"initial values: {[item[-1] for item in old_value_pairs_set]}")
|
||||
results_json_path = os.path.join(save_folder, "results.json")
|
||||
print(f"saving results to\n{results_json_path}")
|
||||
|
||||
for i_step in range(max_num_steps):
|
||||
print(f"\nStep {i_step}:")
|
||||
meta_prompt = gen_meta_prompt(
|
||||
old_value_pairs_set,
|
||||
X,
|
||||
y,
|
||||
num_input_decimals=num_input_decimals,
|
||||
num_output_decimals=num_output_decimals,
|
||||
max_num_pairs=max_num_pairs,
|
||||
)
|
||||
if not i_step % 5:
|
||||
print("\n=================================================")
|
||||
print(f"meta_prompt:\n{meta_prompt}")
|
||||
meta_prompts_dict[i_step] = meta_prompt
|
||||
|
||||
# generate a maximum of the given number of points in each step
|
||||
remaining_num_points_to_generate = num_generated_points_in_each_step
|
||||
raw_outputs = []
|
||||
while remaining_num_points_to_generate > 0:
|
||||
raw_outputs += call_optimizer_server_func(meta_prompt)
|
||||
remaining_num_points_to_generate -= optimizer_llm_dict["batch_size"]
|
||||
raw_outputs = raw_outputs[:num_generated_points_in_each_step]
|
||||
|
||||
raw_outputs_dict[i_step] = raw_outputs
|
||||
parsed_outputs = []
|
||||
for string in raw_outputs:
|
||||
if not i_step % 5:
|
||||
print("\n=================================================")
|
||||
print("raw output:\n", string)
|
||||
print("\n=================================================")
|
||||
try:
|
||||
parsed_output = parse_output(
|
||||
extract_string_in_square_brackets(string)
|
||||
)
|
||||
if parsed_output is not None and len(parsed_output) == 2:
|
||||
parsed_outputs.append(parsed_output)
|
||||
except ValueError:
|
||||
pass
|
||||
parsed_outputs = [tuple(item) for item in parsed_outputs]
|
||||
print(f"proposed points before rounding: {parsed_outputs}")
|
||||
|
||||
# round the proposed points to the number of decimals in meta-prompt
|
||||
rounded_outputs = [
|
||||
(np.round(w, num_input_decimals), np.round(b, num_input_decimals))
|
||||
for w, b in parsed_outputs
|
||||
]
|
||||
rounded_outputs = [
|
||||
tuple(item) for item in list(np.unique(rounded_outputs, axis=0))
|
||||
]
|
||||
print(f"proposed points after rounding: {rounded_outputs}")
|
||||
|
||||
# evaluate the values of proposed and rounded outputs
|
||||
single_step_values = []
|
||||
for w, b in rounded_outputs:
|
||||
if w == w_true and b == b_true:
|
||||
found_optimal = True
|
||||
z = evaluate_loss(X, y, w, b)
|
||||
single_step_values.append(z)
|
||||
old_value_pairs_set.add((w, b, z))
|
||||
old_value_pairs_with_i_step.append((w, b, z, i_step))
|
||||
print(f"single_step_values: {single_step_values}")
|
||||
|
||||
# ====================== save results ============================
|
||||
results_dict_single_rep = {
|
||||
"meta_prompts": meta_prompts_dict,
|
||||
"raw_outputs": raw_outputs_dict,
|
||||
"old_value_pairs_with_i_step": old_value_pairs_with_i_step,
|
||||
}
|
||||
results_dict[i_rep] = results_dict_single_rep
|
||||
with open(results_json_path, "w") as f:
|
||||
json.dump(results_dict, f, indent=4)
|
||||
if found_optimal:
|
||||
print(
|
||||
f"Repetition {i_rep+1}, optimal found at Step {i_step+1}, saving"
|
||||
f" final results to\n{save_folder}"
|
||||
)
|
||||
num_convergence_steps.append(i_step + 1)
|
||||
break
|
||||
print(f"num_convergence_steps: {num_convergence_steps}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
@@ -1,430 +0,0 @@
|
||||
# Copyright 2024 The OPRO Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
r"""Optimize over the objective function of a traveling salesman problem.
|
||||
|
||||
Usage:
|
||||
|
||||
```
|
||||
python optimize_tsp.py --optimizer="text-bison"
|
||||
```
|
||||
|
||||
Note:
|
||||
- When using a Google-Cloud-served model (like text-bison at
|
||||
https://developers.generativeai.google/tutorials/text_quickstart), add
|
||||
`--palm_api_key="<your_key>"`
|
||||
- When using an OpenAI model, add `--openai_api_key="<your_key>"`
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import getpass
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import itertools
|
||||
|
||||
OPRO_ROOT_PATH = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
)
|
||||
sys.path.insert(0, OPRO_ROOT_PATH)
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import google.generativeai as palm
|
||||
import numpy as np
|
||||
import openai
|
||||
|
||||
from opro import prompt_utils
|
||||
|
||||
|
||||
_OPENAI_API_KEY = flags.DEFINE_string(
|
||||
"openai_api_key", "", "The OpenAI API key."
|
||||
)
|
||||
|
||||
_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.")
|
||||
|
||||
_OPTIMIZER = flags.DEFINE_string(
|
||||
"optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM."
|
||||
)
|
||||
|
||||
_START_ALGORITHM = flags.DEFINE_string(
|
||||
"starting_algorithm", "farthest_insertion", "The name of the starting algorithm. Select from [dp, nearest_neighbor, farthest_insertion]"
|
||||
)
|
||||
|
||||
def main(_):
|
||||
# ============== set optimization experiment configurations ================
|
||||
num_points = 100 # number of points in TSP
|
||||
num_steps = 500 # the number of optimization steps
|
||||
max_num_pairs = 10 # the maximum number of input-output pairs in meta-prompt
|
||||
num_decimals = 0 # num of decimals for distances in meta-prompt
|
||||
num_starting_points = 5 # the number of initial points for optimization
|
||||
num_decode_per_step = 8 # the number of decoded solutions per step
|
||||
|
||||
# ================ load LLM settings ===================
|
||||
optimizer_llm_name = _OPTIMIZER.value
|
||||
assert optimizer_llm_name in {
|
||||
"text-bison",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4",
|
||||
}
|
||||
openai_api_key = _OPENAI_API_KEY.value
|
||||
palm_api_key = _PALM_API_KEY.value
|
||||
|
||||
if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
|
||||
assert openai_api_key, "The OpenAI API key must be provided."
|
||||
openai.api_key = openai_api_key
|
||||
else:
|
||||
assert optimizer_llm_name == "text-bison"
|
||||
assert (
|
||||
palm_api_key
|
||||
), "A PaLM API key is needed when prompting the text-bison model."
|
||||
palm.configure(api_key=palm_api_key)
|
||||
|
||||
# =================== create the result directory ==========================
|
||||
datetime_str = (
|
||||
str(datetime.datetime.now().replace(microsecond=0))
|
||||
.replace(" ", "-")
|
||||
.replace(":", "-")
|
||||
)
|
||||
|
||||
save_folder = os.path.join(
|
||||
OPRO_ROOT_PATH,
|
||||
"outputs",
|
||||
"optimization-results",
|
||||
f"tsp-o-{optimizer_llm_name}-{datetime_str}/",
|
||||
)
|
||||
os.makedirs(save_folder)
|
||||
print(f"result directory:\n{save_folder}")
|
||||
|
||||
# ====================== optimizer model configs ============================
|
||||
if optimizer_llm_name.lower() == "text-bison":
|
||||
# when prompting text-bison with Cloud API
|
||||
optimizer_finetuned_palm_temperature = 1.0
|
||||
optimizer_finetuned_palm_max_decode_steps = 1024
|
||||
optimizer_finetuned_palm_batch_size = 1
|
||||
optimizer_finetuned_palm_num_servers = 1
|
||||
optimizer_finetuned_palm_dict = dict()
|
||||
optimizer_finetuned_palm_dict["temperature"] = (
|
||||
optimizer_finetuned_palm_temperature
|
||||
)
|
||||
optimizer_finetuned_palm_dict["batch_size"] = (
|
||||
optimizer_finetuned_palm_batch_size
|
||||
)
|
||||
optimizer_finetuned_palm_dict["num_servers"] = (
|
||||
optimizer_finetuned_palm_num_servers
|
||||
)
|
||||
optimizer_finetuned_palm_dict["max_decode_steps"] = (
|
||||
optimizer_finetuned_palm_max_decode_steps
|
||||
)
|
||||
|
||||
call_optimizer_finetuned_palm_server_func = functools.partial(
|
||||
prompt_utils.call_palm_server_from_cloud,
|
||||
# prompt_utils.call_vllm,
|
||||
model="text-bison-001",
|
||||
temperature=optimizer_finetuned_palm_dict["temperature"],
|
||||
max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"],
|
||||
)
|
||||
|
||||
optimizer_llm_dict = {
|
||||
"model_type": optimizer_llm_name.lower(),
|
||||
}
|
||||
optimizer_llm_dict.update(optimizer_finetuned_palm_dict)
|
||||
call_optimizer_server_func = call_optimizer_finetuned_palm_server_func
|
||||
|
||||
else:
|
||||
assert optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}
|
||||
optimizer_gpt_max_decode_steps = 1024
|
||||
optimizer_gpt_temperature = 1.0
|
||||
|
||||
optimizer_llm_dict = dict()
|
||||
optimizer_llm_dict["max_decode_steps"] = optimizer_gpt_max_decode_steps
|
||||
optimizer_llm_dict["temperature"] = optimizer_gpt_temperature
|
||||
optimizer_llm_dict["batch_size"] = 1
|
||||
call_optimizer_server_func = functools.partial(
|
||||
prompt_utils.call_openai_server_func,
|
||||
model=optimizer_llm_name,
|
||||
max_decode_steps=optimizer_gpt_max_decode_steps,
|
||||
temperature=optimizer_gpt_temperature,
|
||||
)
|
||||
|
||||
# ====================== try calling the servers ============================
|
||||
print("\n======== testing the optimizer server ===========")
|
||||
optimizer_test_output = call_optimizer_server_func(
|
||||
"Does the sun rise from the north? Just answer yes or no.",
|
||||
temperature=1.0,
|
||||
)
|
||||
print(f"optimizer test output: {optimizer_test_output}")
|
||||
print("Finished testing the optimizer server.")
|
||||
print("\n=================================================")
|
||||
|
||||
# ====================== utility functions ============================
|
||||
def evaluate_distance(x, y, trace, num_decimals): # pylint: disable=invalid-name
|
||||
dis = 0
|
||||
try:
|
||||
for i in range(len(trace) - 1):
|
||||
id0 = trace[i]
|
||||
id1 = trace[i + 1]
|
||||
dis += np.sqrt((x[id0] - x[id1]) ** 2 + (y[id0] - y[id1]) ** 2)
|
||||
except Exception:  # invalid trace, e.g. an out-of-range point index
|
||||
return -1
|
||||
id0 = trace[-1]
|
||||
id1 = trace[0]
|
||||
dis += np.sqrt((x[id0] - x[id1]) ** 2 + (y[id0] - y[id1]) ** 2)
|
||||
dis = np.round(dis, num_decimals) if num_decimals > 0 else int(dis)
|
||||
return dis
|
||||
|
||||
def solve_tsp(x, y, num_points, num_decimals, starting_algorithm):
|
||||
if starting_algorithm == "nearest_neighbor":
|
||||
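# greedy nearest-neighbor construction: start at point 0 and repeatedly
# extend the tour with the closest unvisited point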
min_dis = 0
|
||||
gt_sol = [0]
|
||||
remaining_points = list(range(1, num_points))
|
||||
while len(remaining_points) > 0:
|
||||
min_p = -1
|
||||
min_cur_dis = -1
|
||||
for p in remaining_points:
|
||||
cur_dis = np.sqrt((x[p] - x[gt_sol[-1]]) ** 2 + (y[p] - y[gt_sol[-1]]) ** 2)
|
||||
if min_p == -1 or cur_dis < min_cur_dis:
|
||||
min_p = p
|
||||
min_cur_dis = cur_dis
|
||||
gt_sol.append(min_p)
|
||||
min_dis += min_cur_dis
|
||||
remaining_points.remove(min_p)
|
||||
min_dis += np.sqrt((x[0] - x[gt_sol[-1]]) ** 2 + (y[0] - y[gt_sol[-1]]) ** 2)
|
||||
min_dis = np.round(min_dis, num_decimals) if num_decimals > 0 else int(min_dis)
|
||||
return gt_sol, min_dis
|
||||
elif starting_algorithm == "farthest_insertion":
|
||||
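# farthest insertion: for each remaining point, find its best insertion
# position (lowest resulting tour length), then insert the point whose
# best-case length is largest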
gt_sol = [0]
|
||||
remaining_points = list(range(1, num_points))
|
||||
while len(remaining_points) > 0:
|
||||
max_p = -1
|
||||
max_cur_dis = -1
|
||||
max_cur_index = -1
|
||||
for p in remaining_points:
|
||||
min_cur_dis = -1
|
||||
min_cur_index = -1
|
||||
for index in range(1, len(gt_sol) + 1):
|
||||
new_sol = gt_sol[:index] + [p] + gt_sol[index:]
|
||||
cur_dis = evaluate_distance(x, y, new_sol, num_decimals)
|
||||
if min_cur_dis == -1 or cur_dis < min_cur_dis:
|
||||
min_cur_dis = cur_dis
|
||||
min_cur_index = index
|
||||
if max_cur_dis == -1 or min_cur_dis > max_cur_dis:
|
||||
max_p = p
|
||||
max_cur_dis = min_cur_dis
|
||||
max_cur_index = min_cur_index
|
||||
gt_sol = gt_sol[:max_cur_index] + [max_p] + gt_sol[max_cur_index:]
|
||||
remaining_points.remove(max_p)
|
||||
min_dis = evaluate_distance(x, y, gt_sol, num_decimals)
|
||||
return gt_sol, min_dis
|
||||
|
||||
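# exact bitmask dynamic programming (Held-Karp): f[(p, status)] stores the
# best tour found so far that starts at point 0, visits exactly the points
# in bitmask `status`, and ends at point p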
f = {(0, 1): (0, [0])}
|
||||
q = [(0, 1)]
|
||||
min_dis = -1
|
||||
gt_sol = list(range(num_points))
|
||||
while len(q) > 0:
|
||||
p, status = q[0]
|
||||
q = q[1:]
|
||||
for i in range(num_points):
|
||||
if 2 << i >> 1 & status == 0:
|
||||
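# `2 << i >> 1` equals `1 << i`, so this branch runs when point i has not
# been visited yet according to `status`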
new_status = status + (2 << i >> 1)
|
||||
new_dis = f[(p, status)][0] + np.sqrt((x[i] - x[p]) ** 2 + (y[i] - y[p]) ** 2)
|
||||
if (i, new_status) not in f or new_dis < f[(i, new_status)][0]:
|
||||
f[(i, new_status)] = (new_dis, f[(p, status)][1] + [i])
|
||||
if new_status == (2 << num_points >> 1) - 1:
|
||||
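# all points visited: close the tour back to point 0 and update the
# incumbent best tour if shorter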
new_dis += np.sqrt((x[i] - x[0]) ** 2 + (y[i] - y[0]) ** 2)
|
||||
if min_dis == -1 or new_dis < min_dis:
|
||||
min_dis = new_dis
|
||||
gt_sol = f[(i, new_status)][1][:]
|
||||
elif (i, new_status) not in q:
|
||||
q.append((i, new_status))
|
||||
min_dis = np.round(min_dis, num_decimals) if num_decimals > 0 else int(min_dis)
|
||||
return gt_sol, min_dis
|
||||
|
||||
def gen_meta_prompt(
|
||||
old_value_pairs_set,
|
||||
x, # pylint: disable=invalid-name
|
||||
y,
|
||||
max_num_pairs=100,
|
||||
):
|
||||
"""Generate the meta-prompt for optimization.
|
||||
|
||||
Args:
|
||||
old_value_pairs_set (set): the set of old traces.
|
||||
x (np.array): the 1D array of x values.
y (np.array): the 1D array of y values.
|
||||
max_num_pairs (int): the maximum number of exemplars in the meta-prompt.
|
||||
|
||||
Returns:
|
||||
meta_prompt (str): the generated meta-prompt.
|
||||
"""
|
||||
old_value_pairs = list(old_value_pairs_set)
|
||||
old_value_pairs = sorted(old_value_pairs, key=lambda x: -x[1])[
|
||||
-max_num_pairs:
|
||||
]
|
||||
old_value_pairs_substr = ""
|
||||
for trace, dis in old_value_pairs:
|
||||
old_value_pairs_substr += f"\n<trace> {trace} </trace>\nlength:\n{dis}\n"
|
||||
meta_prompt = "You are given a list of points with coordinates below:\n"
|
||||
for i, (xi, yi) in enumerate(zip(x, y)):
|
||||
if i:
|
||||
meta_prompt += ", "
|
||||
meta_prompt += f"({i}): ({xi}, {yi})"
|
||||
meta_prompt += ".\n\nBelow are some previous traces and their lengths. The traces are arranged in descending order based on their lengths, where lower values are better.".strip()
|
||||
meta_prompt += "\n\n"
|
||||
meta_prompt += old_value_pairs_substr.strip()
|
||||
meta_prompt += "\n\n"
|
||||
meta_prompt += """Give me a new trace that is different from all traces above, and has a length lower than any of the above. The trace should traverse all points exactly once. The trace should start with '<trace>' and end with </trace>.
|
||||
""".strip()
|
||||
return meta_prompt
|
||||
|
||||
def extract_string(input_string):
|
||||
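# parse the optimizer output: keep only the comma-separated integer point
# indices found between <trace> and </trace>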
start_string = "<trace>"
|
||||
end_string = "</trace>"
|
||||
if start_string not in input_string:
|
||||
return ""
|
||||
input_string = input_string[input_string.index(start_string) + len(start_string):]
|
||||
if end_string not in input_string:
|
||||
return ""
|
||||
input_string = input_string[:input_string.index(end_string)]
|
||||
parsed_list = []
|
||||
for p in input_string.split(","):
|
||||
p = p.strip()
|
||||
try:
|
||||
p = int(p)
|
||||
except ValueError:
|
||||
continue
|
||||
parsed_list.append(p)
|
||||
return parsed_list
|
||||
|
||||
# ================= generate the ground truth trace =====================
|
||||
|
||||
x = np.random.uniform(low=-100, high=100, size=num_points)
|
||||
y = np.random.uniform(low=-100, high=100, size=num_points)
|
||||
x = [np.round(xi, num_decimals) if num_decimals > 0 else int(xi) for xi in x]
|
||||
y = [np.round(yi, num_decimals) if num_decimals > 0 else int(yi) for yi in y]
|
||||
|
||||
starting_algorithm = _START_ALGORITHM.value
|
||||
|
||||
gt_sol, min_dis = solve_tsp(x, y, num_points, num_decimals, starting_algorithm)
|
||||
print("ground truth solution" + str(gt_sol))
|
||||
print("min distance: ", min_dis)
|
||||
gt_sol_str = ",".join([str(i) for i in gt_sol])
|
||||
point_list = range(num_points)
|
||||
init_sols = []
|
||||
while len(init_sols) < num_starting_points:
|
||||
sol = np.random.permutation(point_list)
|
||||
if sol[0] != 0:
|
||||
continue
|
||||
sol_str = ",".join([str(i) for i in sol])
|
||||
if sol_str == gt_sol_str:
|
||||
continue
|
||||
init_sols.append(list(sol))
|
||||
|
||||
# ====================== run optimization ============================
|
||||
configs_dict = {
|
||||
"num_starting_points": num_starting_points,
|
||||
"num_decode_per_step": num_decode_per_step,
|
||||
"optimizer_llm_configs": optimizer_llm_dict,
|
||||
"data": {
|
||||
"ground truth solution": [",".join([str(i) for i in gt_sol])],
|
||||
"loss_at_true_values": min_dis,
|
||||
"x": list(x),
|
||||
"y": list(y),
|
||||
},
|
||||
"init_sols": [",".join([str(i) for i in sol]) for sol in init_sols],
|
||||
"num_steps": num_steps,
|
||||
"max_num_pairs": max_num_pairs,
|
||||
"num_decimals": num_decimals,
|
||||
}
|
||||
configs_json_path = os.path.join(save_folder, "configs.json")
|
||||
print(f"saving configs to\n{configs_json_path}")
|
||||
with open(configs_json_path, "w") as f:
|
||||
json.dump(configs_dict, f, indent=4)
|
||||
|
||||
old_value_pairs_set = set()
|
||||
old_value_pairs_with_i_step = [] # format: [(trace, dis = f(trace), i_step)]
|
||||
meta_prompts_dict = dict() # format: {i_step: meta_prompt}
|
||||
raw_outputs_dict = dict() # format: {i_step: raw_outputs}
|
||||
|
||||
for sol in init_sols:
|
||||
dis = evaluate_distance(x, y, sol, num_decimals)
|
||||
sol_str = ",".join([str(i) for i in sol])
|
||||
old_value_pairs_set.add((sol_str, dis))
|
||||
old_value_pairs_with_i_step.append((sol_str, dis, -1))
|
||||
|
||||
print("\n================ run optimization ==============")
|
||||
print(f"initial points: {[tuple(item[:-1]) for item in old_value_pairs_set]}")
|
||||
print(f"initial values: {[item[-1] for item in old_value_pairs_set]}")
|
||||
results_json_path = os.path.join(save_folder, "results.json")
|
||||
print(f"saving results to\n{results_json_path}")
|
||||
|
||||
for i_step in range(num_steps):
|
||||
print(f"\nStep {i_step}:")
|
||||
meta_prompt = gen_meta_prompt(
|
||||
old_value_pairs_set,
|
||||
x,
|
||||
y,
|
||||
max_num_pairs=max_num_pairs,
|
||||
)
|
||||
print("\n=================================================")
|
||||
print(f"meta_prompt:\n{meta_prompt}")
|
||||
meta_prompts_dict[i_step] = meta_prompt
|
||||
raw_outputs = []
|
||||
parsed_outputs = []
|
||||
while len(parsed_outputs) < num_decode_per_step:
|
||||
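# keep querying the optimizer until this step has num_decode_per_step valid
# traces; outputs that repeat or miss points, do not start at point 0, or
# contain an invalid index are discarded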
raw_output = call_optimizer_server_func(meta_prompt)
|
||||
for string in raw_output:
|
||||
print("\n=================================================")
|
||||
print("raw output:\n", string)
|
||||
try:
|
||||
parsed_output = extract_string(string)
|
||||
if parsed_output is not None and len(set(parsed_output)) == num_points and len(parsed_output) == num_points and parsed_output[0] == 0:
|
||||
dis = evaluate_distance(x, y, parsed_output, num_decimals)
|
||||
if dis == -1:
|
||||
continue
|
||||
parsed_outputs.append(parsed_output)
|
||||
raw_outputs.append(string)
|
||||
except Exception:
|
||||
pass
|
||||
print("\n=================================================")
|
||||
print(f"proposed points: {parsed_outputs}")
|
||||
raw_outputs_dict[i_step] = raw_outputs
|
||||
|
||||
# evaluate the values of proposed and rounded outputs
|
||||
single_step_values = []
|
||||
for trace in parsed_outputs:
|
||||
dis = evaluate_distance(x, y, trace, num_decimals)
|
||||
single_step_values.append(dis)
|
||||
trace_str = ",".join([str(i) for i in trace])
|
||||
old_value_pairs_set.add((trace_str, dis))
|
||||
old_value_pairs_with_i_step.append((trace_str, dis, i_step))
|
||||
print(f"single_step_values: {single_step_values}")
|
||||
print("ground truth solution" + str(gt_sol))
|
||||
print("min distance: ", min_dis)
|
||||
|
||||
# ====================== save results ============================
|
||||
results_dict = {
|
||||
"meta_prompts": meta_prompts_dict,
|
||||
"raw_outputs": raw_outputs_dict,
|
||||
"old_value_pairs_with_i_step": old_value_pairs_with_i_step,
|
||||
}
|
||||
with open(results_json_path, "w") as f:
|
||||
json.dump(results_dict, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
@@ -1,967 +0,0 @@
|
||||
# Copyright 2023 The OPRO Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
r"""The .py file for prompt optimization.
|
||||
|
||||
Usage:
|
||||
|
||||
Step 1: edit the starting instructions by modifying `initial_instructions`
|
||||
|
||||
Step 2: edit the training ratio by modifying `train_ratio`
|
||||
|
||||
Step 3: check if the model configs (like batch size) are the same as the actual serving configs
|
||||
|
||||
Step 4: run
|
||||
|
||||
```
|
||||
python optimize_instructions.py \
|
||||
--optimizer="gpt-3.5-turbo" --scorer="text-bison" \
|
||||
--instruction_pos="A_begin" --dataset="gsm8k" --task="train"
|
||||
```
|
||||
|
||||
The outputs will then be written to `outputs/optimization-results/` in the opro folder.
|
||||
|
||||
Notes:
|
||||
|
||||
1. One or more API keys may need to be provided:
|
||||
- When using a Google-Cloud-served model (like text-bison at https://developers.generativeai.google/tutorials/text_quickstart), add `--palm_api_key=<your_key>`
|
||||
- When using an OpenAI model, add `--openai_api_key="<your_key>"`
|
||||
|
||||
2. The initial instructions should be provided in the "initial_instructions"
|
||||
variable.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
|
||||
OPRO_ROOT_PATH = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
)
|
||||
sys.path.insert(0, OPRO_ROOT_PATH)
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import google.generativeai as palm
|
||||
import numpy as np
|
||||
import openai
|
||||
from opro import prompt_utils
|
||||
from opro.optimization import opt_utils
|
||||
import pandas as pd
|
||||
|
||||
ROOT_DATA_FOLDER_PATH = os.path.join(OPRO_ROOT_PATH, "data")
|
||||
|
||||
_LOCAL_MODEL_PATH = flags.DEFINE_string("local_model_path", None, "Path to local vLLM model.")
|
||||
|
||||
|
||||
_OPENAI_API_KEY = flags.DEFINE_string(
|
||||
"openai_api_key", "", "The OpenAI API key."
|
||||
)
|
||||
|
||||
_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.")
|
||||
|
||||
_SCORER = flags.DEFINE_string(
|
||||
"scorer", "text-bison", "The name of the scorer LLM."
|
||||
)
|
||||
|
||||
_OPTIMIZER = flags.DEFINE_string(
|
||||
"optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM."
|
||||
)
|
||||
|
||||
_DATASET = flags.DEFINE_string(
|
||||
"dataset", "gsm8k", "The name of dataset to search for instructions on."
|
||||
)
|
||||
|
||||
_TASK = flags.DEFINE_string(
|
||||
"task",
|
||||
"train",
|
||||
"The name of task within the above dataset to search for instructions on.",
|
||||
)
|
||||
|
||||
_INSTRUCTION_POS = flags.DEFINE_string(
|
||||
"instruction_pos",
|
||||
"A_begin",
|
||||
"The position of the instruction to search for.",
|
||||
)
|
||||
|
||||
_META_PROMPT_TYPE = flags.DEFINE_string(
|
||||
"meta_prompt_type",
|
||||
"both_instructions_and_exemplars",
|
||||
"The type of meta-prompt: whether to have both previous instructions and"
|
||||
" dataset exemplars (often for fine-tuned optimizers), or to have only"
|
||||
" previous instructions (often for pre-trained optimizers).",
|
||||
)
|
||||
|
||||
|
||||
def main(_):
|
||||
openai_api_key = _OPENAI_API_KEY.value
|
||||
palm_api_key = _PALM_API_KEY.value
|
||||
scorer_llm_name = _SCORER.value
|
||||
optimizer_llm_name = _OPTIMIZER.value
|
||||
dataset_name = _DATASET.value.lower()
|
||||
task_name = _TASK.value
|
||||
meta_prompt_type = _META_PROMPT_TYPE.value
|
||||
local_model_path = _LOCAL_MODEL_PATH.value
|
||||
|
||||
assert dataset_name in {
|
||||
"mmlu",
|
||||
"bbh",
|
||||
"gsm8k",
|
||||
}, "The lower-case dataset name must be one of mmlu, bbh, or gsm8k."
|
||||
if dataset_name == "mmlu":
|
||||
assert task_name in {
|
||||
"STEM",
|
||||
"humanities",
|
||||
"social sciences",
|
||||
"otheran (business, health, misc.)",
|
||||
} # for now only support searching on one MMLU category
|
||||
elif dataset_name == "bbh":
|
||||
assert task_name in {
|
||||
"boolean_expressions",
|
||||
"causal_judgement",
|
||||
"date_understanding",
|
||||
"disambiguation_qa",
|
||||
"dyck_languages",
|
||||
"formal_fallacies",
|
||||
"geometric_shapes",
|
||||
"hyperbaton",
|
||||
"logical_deduction_five_objects",
|
||||
"logical_deduction_seven_objects",
|
||||
"logical_deduction_three_objects",
|
||||
"movie_recommendation",
|
||||
"multistep_arithmetic_two",
|
||||
"navigate",
|
||||
"object_counting",
|
||||
"penguins_in_a_table",
|
||||
"reasoning_about_colored_objects",
|
||||
"ruin_names",
|
||||
"salient_translation_error_detection",
|
||||
"snarks",
|
||||
"sports_understanding",
|
||||
"temporal_sequences",
|
||||
"tracking_shuffled_objects_five_objects",
|
||||
"tracking_shuffled_objects_seven_objects",
|
||||
"tracking_shuffled_objects_three_objects",
|
||||
"web_of_lies",
|
||||
"word_sorting",
|
||||
}
|
||||
else:
|
||||
assert dataset_name == "gsm8k"
|
||||
assert task_name in {"train", "test"}
|
||||
|
||||
assert scorer_llm_name in {
|
||||
"text-bison",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4",
|
||||
"local",
|
||||
}
|
||||
assert optimizer_llm_name in {
|
||||
"text-bison",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4",
|
||||
"local",
|
||||
}
|
||||
assert meta_prompt_type in {
|
||||
"both_instructions_and_exemplars",
|
||||
"instructions_only",
|
||||
}
|
||||
|
||||
instruction_pos = _INSTRUCTION_POS.value
|
||||
assert instruction_pos in {
|
||||
"before_Q",
|
||||
"Q_begin",
|
||||
"Q_end",
|
||||
"A_begin",
|
||||
}, (
|
||||
"The instruction position should be either before the question, or at the"
|
||||
" beginning of the question, at the end of the question, or at the"
|
||||
" beginning of the answer."
|
||||
)
|
||||
print(
|
||||
f"scorer: {scorer_llm_name}, optimizer: {optimizer_llm_name}, dataset:"
|
||||
f" {dataset_name}, task: {task_name}, instruction_pos: {instruction_pos}"
|
||||
)
|
||||
|
||||
# make sure the scorer and optimizer models are callable
|
||||
|
||||
if scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
|
||||
assert openai_api_key, "The OpenAI API key must be provided."
|
||||
openai.api_key = openai_api_key
|
||||
elif scorer_llm_name == "text-bison":
|
||||
assert scorer_llm_name == "text-bison"
|
||||
assert (
|
||||
palm_api_key
|
||||
), "A PaLM API key is needed when prompting the text-bison model."
|
||||
palm.configure(api_key=palm_api_key)
|
||||
|
||||
elif scorer_llm_name == "local":
|
||||
assert local_model_path, "The local model path must be provided."
|
||||
assert os.path.exists(local_model_path), (
|
||||
f"The local model path {local_model_path} does not exist."
|
||||
)
|
||||
# set the local model path for vLLM
|
||||
# prompt_utils.call_local_server_func(local_model_path)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown scorer_llm_name: {scorer_llm_name}. "
|
||||
"It should be one of text-bison, gpt-3.5-turbo, gpt-4, or local."
|
||||
)
|
||||
|
||||
if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
|
||||
assert openai_api_key, "The OpenAI API key must be provided."
|
||||
openai.api_key = openai_api_key
|
||||
elif optimizer_llm_name == "text-bison":
|
||||
assert optimizer_llm_name == "text-bison"
|
||||
assert (
|
||||
palm_api_key
|
||||
), "A PaLM API key is needed when prompting the text-bison model."
|
||||
palm.configure(api_key=palm_api_key)
|
||||
|
||||
elif optimizer_llm_name == "local":
|
||||
assert local_model_path, "The local model path must be provided."
|
||||
assert os.path.exists(local_model_path), (
|
||||
f"The local model path {local_model_path} does not exist."
|
||||
)
|
||||
# set the local model path for vLLM
|
||||
# prompt_utils.call_local_server_func(local_model_path)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown scorer_llm_name: {optimizer_llm_name}. "
|
||||
"It should be one of text-bison, gpt-3.5-turbo, gpt-4, or local."
|
||||
)
|
||||
|
||||
|
||||
if dataset_name == "mmlu":
|
||||
root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "MMLU-data")
|
||||
elif dataset_name == "bbh":
|
||||
root_data_folder_path = os.path.join(
|
||||
ROOT_DATA_FOLDER_PATH, "BIG-Bench-Hard-data/"
|
||||
)
|
||||
else:
|
||||
assert dataset_name == "gsm8k"
|
||||
root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "gsm_data")
|
||||
|
||||
# =================== create the result directory ==========================
|
||||
datetime_str = (
|
||||
str(datetime.datetime.now().replace(microsecond=0))
|
||||
.replace(" ", "-")
|
||||
.replace(":", "-")
|
||||
)
|
||||
|
||||
save_folder = os.path.join(
|
||||
OPRO_ROOT_PATH,
|
||||
"outputs",
|
||||
"optimization-results",
|
||||
f"{dataset_name.upper()}-{task_name}-s-{scorer_llm_name}-o-{optimizer_llm_name}-{datetime_str}/",
|
||||
)
|
||||
result_by_instruction_folder = os.path.join(
|
||||
save_folder, "result_by_instruction"
|
||||
)
|
||||
os.makedirs(result_by_instruction_folder)
|
||||
print(f"result directory:\n{save_folder}")
|
||||
|
||||
# ====================== scorer model configs ==============================
|
||||
# difference between num_decodes and batch_size:
|
||||
# - num_decodes: how many outputs we actually want for each input
|
||||
# - batch_size: the batch size in model serving, should equal to that in
|
||||
# model serving config
|
||||
# constant definitions
|
||||
DEFAULT_MAX_TOKENS = 1024
|
||||
DEFAULT_TEMPERATURE = 0.0
|
||||
PALM_MODEL_NAME = "text-bison-001"
|
||||
|
||||
if scorer_llm_name == "text-bison":
|
||||
config = {
|
||||
"temperature": DEFAULT_TEMPERATURE,
|
||||
"max_decode_steps": DEFAULT_MAX_TOKENS,
|
||||
"batch_size": 1,
|
||||
"num_servers": 1,
|
||||
}
|
||||
call_scorer_server_func = functools.partial(
|
||||
prompt_utils.call_palm_server_from_cloud,
|
||||
model=PALM_MODEL_NAME,
|
||||
**config
|
||||
)
|
||||
scorer_llm_dict = {"model_type": "text-bison", **config}
|
||||
|
||||
elif scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
|
||||
config = {
|
||||
"temperature": DEFAULT_TEMPERATURE,
|
||||
"max_decode_steps": DEFAULT_MAX_TOKENS,
|
||||
"batch_size": 1,
|
||||
"num_servers": 1,
|
||||
}
|
||||
call_scorer_server_func = functools.partial(
|
||||
prompt_utils.call_openai_server_func,
|
||||
model=scorer_llm_name.lower(),
|
||||
**config
|
||||
)
|
||||
scorer_llm_dict = {"model_type": scorer_llm_name.lower(), **config}
|
||||
|
||||
elif scorer_llm_name == "local":
|
||||
print(f"[DEBUG] local_model_path: {local_model_path}")
|
||||
assert local_model_path, "Local model path must be provided."
|
||||
config = {
|
||||
"temperature": DEFAULT_TEMPERATURE,
|
||||
"max_decode_steps": DEFAULT_MAX_TOKENS,
|
||||
"batch_size": 8,
|
||||
"num_servers": 8,# number of servers to use for local model
|
||||
}
|
||||
call_scorer_server_func = functools.partial(
|
||||
prompt_utils.call_local_server_func,
|
||||
local_model_path=local_model_path,
|
||||
**config
|
||||
)
|
||||
scorer_llm_dict = {"model_type": "local", **config}
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported model: {scorer_llm_name}")
|
||||
|
||||
# if scorer_llm_name == "text-bison":
|
||||
# # when prompting text-bison with Cloud API
|
||||
# scorer_finetuned_palm_temperature = 0.0
|
||||
# scorer_finetuned_palm_max_decode_steps = 1024
|
||||
# scorer_finetuned_palm_batch_size = 1
|
||||
# scorer_finetuned_palm_num_servers = 1
|
||||
# scorer_finetuned_palm_dict = dict()
|
||||
# scorer_finetuned_palm_dict["temperature"] = (
|
||||
# scorer_finetuned_palm_temperature
|
||||
# )
|
||||
# scorer_finetuned_palm_dict["num_servers"] = (
|
||||
# scorer_finetuned_palm_num_servers
|
||||
# )
|
||||
# scorer_finetuned_palm_dict["batch_size"] = scorer_finetuned_palm_batch_size
|
||||
# scorer_finetuned_palm_dict["max_decode_steps"] = (
|
||||
# scorer_finetuned_palm_max_decode_steps
|
||||
# )
|
||||
|
||||
# call_scorer_finetuned_palm_server_func = functools.partial(
|
||||
# prompt_utils.call_palm_server_from_cloud,
|
||||
# model="text-bison-001",
|
||||
# temperature=scorer_finetuned_palm_dict["temperature"],
|
||||
# max_decode_steps=scorer_finetuned_palm_dict["max_decode_steps"],
|
||||
# )
|
||||
|
||||
# scorer_llm_dict = {
|
||||
# "model_type": scorer_llm_name.lower(),
|
||||
# }
|
||||
# scorer_llm_dict.update(scorer_finetuned_palm_dict)
|
||||
# call_scorer_server_func = call_scorer_finetuned_palm_server_func
|
||||
|
||||
# elif scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
|
||||
# # assert scorer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4"}
|
||||
# scorer_gpt_max_decode_steps = 1024
|
||||
# scorer_gpt_temperature = 0.0
|
||||
|
||||
# scorer_gpt_dict = dict()
|
||||
# scorer_gpt_dict["max_decode_steps"] = scorer_gpt_max_decode_steps
|
||||
# scorer_gpt_dict["temperature"] = scorer_gpt_temperature
|
||||
# scorer_gpt_dict["num_decodes"] = 1
|
||||
# scorer_gpt_dict["batch_size"] = 1
|
||||
# scorer_gpt_dict["num_servers"] = 1
|
||||
|
||||
# scorer_llm_dict = {
|
||||
# "model_type": scorer_llm_name.lower(),
|
||||
# }
|
||||
# scorer_llm_dict.update(scorer_gpt_dict)
|
||||
# call_scorer_server_func = functools.partial(
|
||||
# prompt_utils.call_openai_server_func,
|
||||
# model=scorer_llm_name.lower(),
|
||||
# max_decode_steps=scorer_gpt_max_decode_steps,
|
||||
# temperature=scorer_gpt_temperature,
|
||||
# )
|
||||
# elif scorer_llm_name == "local":
|
||||
# # local vLLM model
|
||||
# scorer_local_max_decode_steps = 1024
|
||||
# scorer_local_temperature = 0.0
|
||||
# call_scorer_server_func = functools.partial(
|
||||
# prompt_utils.call_local_model_server_func,
|
||||
# model_path=local_model_path,
|
||||
# max_decode_steps=scorer_local_max_decode_steps,
|
||||
# temperature=scorer_local_temperature,
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError(
|
||||
# f"Unknown scorer_llm_name: {scorer_llm_name}. "
|
||||
# "It should be one of text-bison, gpt-3.5-turbo, gpt-4, or local."
|
||||
# )
|
||||
|
||||
# ====================== optimizer model configs ============================
|
||||
if optimizer_llm_name.lower() == "text-bison":
|
||||
# PaLM text-bison model configs
|
||||
optimizer_llm_dict = {
|
||||
"model_type": "text-bison",
|
||||
"temperature": 1.0, # 更高的随机性以生成多样化解
|
||||
"max_decode_steps": 1024, # 最大生成长度
|
||||
"batch_size": 1, # 单样本处理
|
||||
"num_decodes": 8, # 生成8个候选结果
|
||||
"num_servers": 1 # 单服务器
|
||||
}
|
||||
|
||||
call_optimizer_server_func = functools.partial(
|
||||
prompt_utils.call_palm_server_from_cloud,
|
||||
model="text-bison-001",
|
||||
temperature=optimizer_llm_dict["temperature"],
|
||||
max_decode_steps=optimizer_llm_dict["max_decode_steps"],
|
||||
)
|
||||
|
||||
elif optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4"}:
|
||||
# GPT model configs
|
||||
optimizer_llm_dict = {
|
||||
"model_type": optimizer_llm_name.lower(),
|
||||
"temperature": 1.0, # 更高的随机性
|
||||
"max_decode_steps": 512, # 较短的最大长度
|
||||
"batch_size": 1,
|
||||
"num_decodes": 1 , # 单次生成
|
||||
"num_servers": 1 # 单服务器
|
||||
}
|
||||
|
||||
call_optimizer_server_func = functools.partial(
|
||||
prompt_utils.call_openai_server_func,
|
||||
model=optimizer_llm_name,
|
||||
max_decode_steps=optimizer_llm_dict["max_decode_steps"],
|
||||
temperature=optimizer_llm_dict["temperature"],
|
||||
)
|
||||
elif optimizer_llm_name.lower() == "local":
|
||||
assert local_model_path, "Local model path must be provided."
|
||||
optimizer_llm_dict = {
|
||||
"model_type": optimizer_llm_name.lower(),
|
||||
"temperature": 1.0, # 更高的随机性
|
||||
"max_decode_steps": 512, # 较短的最大长度
|
||||
"batch_size": 8,
|
||||
"num_decodes": 1 , # 单次生成
|
||||
"num_servers": 8 # 单服务器
|
||||
}
|
||||
call_optimizer_server_func = functools.partial(
|
||||
prompt_utils.call_local_server_func,
|
||||
local_model_path=local_model_path,
|
||||
max_decode_steps=optimizer_llm_dict["max_decode_steps"],
|
||||
temperature=optimizer_llm_dict["temperature"],
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported optimizer model: {optimizer_llm_name}. "
|
||||
"Must be one of: text-bison, gpt-3.5-turbo, gpt-4"
|
||||
)
|
||||
|
||||
# if optimizer_llm_name.lower() == "text-bison":
|
||||
# # when prompting text-bison with Cloud API
|
||||
# optimizer_finetuned_palm_temperature = 1.0
|
||||
# optimizer_finetuned_palm_num_decodes = 8
|
||||
# optimizer_finetuned_palm_max_decode_steps = 1024
|
||||
# optimizer_finetuned_palm_batch_size = 1
|
||||
# optimizer_finetuned_palm_num_servers = 1
|
||||
# optimizer_finetuned_palm_dict = dict()
|
||||
# optimizer_finetuned_palm_dict["temperature"] = (
|
||||
# optimizer_finetuned_palm_temperature
|
||||
# )
|
||||
# optimizer_finetuned_palm_dict["num_decodes"] = (
|
||||
# optimizer_finetuned_palm_num_decodes
|
||||
# )
|
||||
# optimizer_finetuned_palm_dict["batch_size"] = (
|
||||
# optimizer_finetuned_palm_batch_size
|
||||
# )
|
||||
# optimizer_finetuned_palm_dict["num_servers"] = (
|
||||
# optimizer_finetuned_palm_num_servers
|
||||
# )
|
||||
# optimizer_finetuned_palm_dict["max_decode_steps"] = (
|
||||
# optimizer_finetuned_palm_max_decode_steps
|
||||
# )
|
||||
|
||||
# call_optimizer_finetuned_palm_server_func = functools.partial(
|
||||
# prompt_utils.call_palm_server_from_cloud,
|
||||
# model="text-bison-001",
|
||||
# temperature=optimizer_finetuned_palm_dict["temperature"],
|
||||
# max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"],
|
||||
# )
|
||||
|
||||
# optimizer_llm_dict = {
|
||||
# "model_type": optimizer_llm_name.lower(),
|
||||
# }
|
||||
# optimizer_llm_dict.update(optimizer_finetuned_palm_dict)
|
||||
# call_optimizer_server_func = call_optimizer_finetuned_palm_server_func
|
||||
|
||||
# else:
|
||||
# assert optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}
|
||||
# optimizer_gpt_max_decode_steps = 512
|
||||
# optimizer_gpt_temperature = 1.0
|
||||
|
||||
# optimizer_llm_dict = dict()
|
||||
# optimizer_llm_dict["max_decode_steps"] = optimizer_gpt_max_decode_steps
|
||||
# optimizer_llm_dict["temperature"] = optimizer_gpt_temperature
|
||||
# optimizer_llm_dict["batch_size"] = 1
|
||||
# optimizer_llm_dict["num_decodes"] = 1
|
||||
# call_optimizer_server_func = functools.partial(
|
||||
# prompt_utils.call_openai_server_func,
|
||||
# model=optimizer_llm_name,
|
||||
# max_decode_steps=optimizer_gpt_max_decode_steps,
|
||||
# temperature=optimizer_gpt_temperature,
|
||||
# )
|
||||
|
||||
# ====================== try calling the servers ============================
|
||||
print("\n======== testing the scorer and optimizer servers ===========")
|
||||
scorer_test_output = call_scorer_server_func(
|
||||
"Does the sun rise from the north? Just answer yes or no."
|
||||
)
|
||||
print(f"number of scorer output decodes: {len(scorer_test_output)}")
|
||||
print(f"scorer test output: {scorer_test_output}")
|
||||
optimizer_test_output = call_optimizer_server_func(
|
||||
"Does the sun rise from the north? Just answer yes or no.",
|
||||
temperature=1.0,
|
||||
)
|
||||
print(f"number of optimizer output decodes: {len(optimizer_test_output)}")
|
||||
print(f"optimizer test output: {optimizer_test_output}")
|
||||
print("Finished testing the servers.")
|
||||
|
||||
# ====================== read data ============================
|
||||
print("\n================ prompt optimization settings ==============")
|
||||
# from https://github.com/hendrycks/test/blob/master/categories.py
|
||||
subcategories = {
|
||||
"abstract_algebra": ["math"],
|
||||
"anatomy": ["health"],
|
||||
"astronomy": ["physics"],
|
||||
"business_ethics": ["business"],
|
||||
"clinical_knowledge": ["health"],
|
||||
"college_biology": ["biology"],
|
||||
"college_chemistry": ["chemistry"],
|
||||
"college_computer_science": ["computer science"],
|
||||
"college_mathematics": ["math"],
|
||||
"college_medicine": ["health"],
|
||||
"college_physics": ["physics"],
|
||||
"computer_security": ["computer science"],
|
||||
"conceptual_physics": ["physics"],
|
||||
"econometrics": ["economics"],
|
||||
"electrical_engineering": ["engineering"],
|
||||
"elementary_mathematics": ["math"],
|
||||
"formal_logic": ["philosophy"],
|
||||
"global_facts": ["other"],
|
||||
"high_school_biology": ["biology"],
|
||||
"high_school_chemistry": ["chemistry"],
|
||||
"high_school_computer_science": ["computer science"],
|
||||
"high_school_european_history": ["history"],
|
||||
"high_school_geography": ["geography"],
|
||||
"high_school_government_and_politics": ["politics"],
|
||||
"high_school_macroeconomics": ["economics"],
|
||||
"high_school_mathematics": ["math"],
|
||||
"high_school_microeconomics": ["economics"],
|
||||
"high_school_physics": ["physics"],
|
||||
"high_school_psychology": ["psychology"],
|
||||
"high_school_statistics": ["math"],
|
||||
"high_school_us_history": ["history"],
|
||||
"high_school_world_history": ["history"],
|
||||
"human_aging": ["health"],
|
||||
"human_sexuality": ["culture"],
|
||||
"international_law": ["law"],
|
||||
"jurisprudence": ["law"],
|
||||
"logical_fallacies": ["philosophy"],
|
||||
"machine_learning": ["computer science"],
|
||||
"management": ["business"],
|
||||
"marketing": ["business"],
|
||||
"medical_genetics": ["health"],
|
||||
"miscellaneous": ["other"],
|
||||
"moral_disputes": ["philosophy"],
|
||||
"moral_scenarios": ["philosophy"],
|
||||
"nutrition": ["health"],
|
||||
"philosophy": ["philosophy"],
|
||||
"prehistory": ["history"],
|
||||
"professional_accounting": ["other"],
|
||||
"professional_law": ["law"],
|
||||
"professional_medicine": ["health"],
|
||||
"professional_psychology": ["psychology"],
|
||||
"public_relations": ["politics"],
|
||||
"security_studies": ["politics"],
|
||||
"sociology": ["culture"],
|
||||
"us_foreign_policy": ["politics"],
|
||||
"virology": ["health"],
|
||||
"world_religions": ["philosophy"],
|
||||
}
|
||||
|
||||
categories = {
|
||||
"STEM": [
|
||||
"physics",
|
||||
"chemistry",
|
||||
"biology",
|
||||
"computer science",
|
||||
"math",
|
||||
"engineering",
|
||||
],
|
||||
"humanities": ["history", "philosophy", "law"],
|
||||
"social sciences": [
|
||||
"politics",
|
||||
"culture",
|
||||
"economics",
|
||||
"geography",
|
||||
"psychology",
|
||||
],
|
||||
"other (business, health, misc.)": ["other", "business", "health"],
|
||||
}
|
||||
|
||||
if dataset_name == "mmlu":
|
||||
# EITHER: filter by category
|
||||
# category_names = [
|
||||
# "STEM",
|
||||
# "humanities",
|
||||
# "social sciences",
|
||||
# "other (business, health, misc.)",
|
||||
# ]
|
||||
category_names = [task_name]
|
||||
folder_name = "test" # one of {'auxiliary_train', 'dev', 'val', 'test'}
|
||||
task_names = []
|
||||
for task_csv_name in os.listdir(
|
||||
os.path.join(root_data_folder_path, folder_name)
|
||||
):
|
||||
task_names.append(task_csv_name.split(".")[0])
|
||||
|
||||
tasks_in_category = []
|
||||
for category_name in category_names:
|
||||
for task_name in task_names:
|
||||
for subname in subcategories:
|
||||
if subname in task_name:
|
||||
if subcategories[subname][0] in categories[category_name]:
|
||||
tasks_in_category.append(task_name)
|
||||
break
|
||||
|
||||
tasks_all = [(folder_name, task_name) for task_name in tasks_in_category]
|
||||
multiple_choice_tasks = set([item[1] for item in tasks_all])
|
||||
boolean_tasks = set()
|
||||
numerical_output_tasks = set()
|
||||
|
||||
# OR: filter by task
|
||||
# tasks_all = [
|
||||
# # ('test', 'abstract_algebra_test'),
|
||||
# # ('test', 'college_computer_science_test'),
|
||||
# # ('test', 'college_mathematics_test'),
|
||||
# # ('test', 'college_physics_test'),
|
||||
# # ('test', 'elementary_mathematics_test'),
|
||||
# # ('test', 'global_facts_test'),
|
||||
# # ('test', 'high_school_physics_test'),
|
||||
# # ('test', 'machine_learning_test'),
|
||||
# # ('test', 'management_test'),
|
||||
# # ('test', 'medical_genetics_test'),
|
||||
# # ('test', 'moral_scenarios_test'),
|
||||
# # ('test', 'professional_psychology_test'),
|
||||
# # ('test', 'public_relations_test'),
|
||||
# # ('test', 'professional_law_test'),
|
||||
# # ('test', 'high_school_psychology_test'),
|
||||
# # ('test', 'high_school_world_history_test'),
|
||||
# # ('test', 'human_aging_test'),
|
||||
# # ('test', 'miscellaneous_test'),
|
||||
# # ('test', 'moral_scenarios_test'),
|
||||
# ('test', 'professional_psychology_test'),
|
||||
# # ('test', 'security_studies_test'),
|
||||
# ]
|
||||
|
||||
elif dataset_name == "bbh":
|
||||
tasks_all = [task_name]
|
||||
assert (
|
||||
len(tasks_all) == 1
|
||||
), "for now only support prompt optimization on one BBH task"
|
||||
|
||||
# all BBH tasks are as below
|
||||
# tasks_all = [
|
||||
# 'boolean_expressions',
|
||||
# 'causal_judgement',
|
||||
# 'date_understanding',
|
||||
# 'disambiguation_qa',
|
||||
# 'dyck_languages',
|
||||
# 'formal_fallacies',
|
||||
# 'geometric_shapes',
|
||||
# 'hyperbaton',
|
||||
# 'logical_deduction_five_objects',
|
||||
# 'logical_deduction_seven_objects',
|
||||
# 'logical_deduction_three_objects',
|
||||
# 'movie_recommendation',
|
||||
# 'multistep_arithmetic_two',
|
||||
# 'navigate',
|
||||
# 'object_counting',
|
||||
# 'penguins_in_a_table',
|
||||
# 'reasoning_about_colored_objects',
|
||||
# 'ruin_names',
|
||||
# 'salient_translation_error_detection',
|
||||
# 'snarks',
|
||||
# 'sports_understanding',
|
||||
# 'temporal_sequences',
|
||||
# 'tracking_shuffled_objects_five_objects',
|
||||
# 'tracking_shuffled_objects_seven_objects',
|
||||
# 'tracking_shuffled_objects_three_objects',
|
||||
# 'web_of_lies',
|
||||
# 'word_sorting'
|
||||
# ]
|
||||
numerical_output_tasks = {
|
||||
"object_counting",
|
||||
"multistep_arithmetic_two",
|
||||
}
|
||||
|
||||
multiple_choice_tasks = {
|
||||
"date_understanding",
|
||||
"disambiguation_qa",
|
||||
"geometric_shapes",
|
||||
"hyperbaton",
|
||||
"logical_deduction_five_objects",
|
||||
"logical_deduction_seven_objects",
|
||||
"logical_deduction_three_objects",
|
||||
"movie_recommendation",
|
||||
"penguins_in_a_table",
|
||||
"reasoning_about_colored_objects",
|
||||
"ruin_names",
|
||||
"salient_translation_error_detection",
|
||||
"snarks",
|
||||
"temporal_sequences",
|
||||
"tracking_shuffled_objects_five_objects",
|
||||
"tracking_shuffled_objects_seven_objects",
|
||||
"tracking_shuffled_objects_three_objects",
|
||||
}
|
||||
|
||||
boolean_tasks = {
|
||||
"boolean_expressions", # True or False
|
||||
"causal_judgement", # yes or no
|
||||
"formal_fallacies", # valid or invalid
|
||||
"navigate", # yes or no
|
||||
"sports_understanding", # yes or no
|
||||
"web_of_lies", # yes or no
|
||||
}
|
||||
|
||||
else:
|
||||
assert dataset_name in {"gsm8k"}
|
||||
tasks_all = [task_name]
|
||||
multiple_choice_tasks = set()
|
||||
boolean_tasks = set()
|
||||
numerical_output_tasks = set(tasks_all)
|
||||
|
||||
if dataset_name == "mmlu":
|
||||
raw_data = pd.DataFrame()
|
||||
prediction_treat_as_number = False
|
||||
prediction_treat_as_bool = False
|
||||
elif dataset_name == "bbh":
|
||||
raw_data = []
|
||||
prediction_treat_as_number = bool(
|
||||
tasks_all[0] in numerical_output_tasks
|
||||
) # for now only check the first task
|
||||
prediction_treat_as_bool = bool(
|
||||
tasks_all[0] in boolean_tasks
|
||||
) # for now only check the first task
|
||||
print(
|
||||
f"prediction_treat_as_number: {prediction_treat_as_number},"
|
||||
f" prediction_treat_as_bool: {prediction_treat_as_bool}"
|
||||
)
|
||||
else:
|
||||
assert dataset_name == "gsm8k"
|
||||
raw_data = pd.DataFrame()
|
||||
prediction_treat_as_number = True
|
||||
prediction_treat_as_bool = False
|
||||
|
||||
for t in tasks_all:
|
||||
if dataset_name == "mmlu":
|
||||
folder_name = t[0]
|
||||
task_name = t[1]
|
||||
single_task_df = pd.read_csv(
|
||||
os.path.join(root_data_folder_path, f"{folder_name}/{task_name}.csv"),
|
||||
index_col=None,
|
||||
header=None,
|
||||
)
|
||||
raw_data = pd.concat([raw_data, single_task_df])
|
||||
elif dataset_name == "bbh":
|
||||
task_name = t
|
||||
single_task_list = opt_utils.load_bbh_task_data(
|
||||
task_name, base_dir=root_data_folder_path
|
||||
)
|
||||
raw_data += single_task_list
|
||||
else:
|
||||
assert dataset_name == "gsm8k"
|
||||
task_name = t
|
||||
f_gsm = os.path.join(root_data_folder_path, f"gsm_{task_name}.tsv")
|
||||
single_task_df = pd.read_csv(f_gsm, sep="\t", header=None)
|
||||
raw_data = pd.concat([raw_data, single_task_df])
|
||||
|
||||
if dataset_name == "mmlu":
|
||||
num_examples = raw_data.shape[0]
|
||||
elif dataset_name == "bbh":
|
||||
num_examples = len(raw_data)
|
||||
else:
|
||||
assert dataset_name in {"gsm8k"}
|
||||
num_examples = raw_data.shape[0]
|
||||
print(f"number of examples in the current task: {num_examples}")
|
||||
|
||||
# ================ split data into train/val/test ==========================
|
||||
if dataset_name == "mmlu":
|
||||
train_ratio = 0.8
|
||||
eval_ratio = 0.2
|
||||
elif dataset_name == "gsm8k":
|
||||
train_ratio = 0.035
|
||||
eval_ratio = 0
|
||||
else:
|
||||
assert dataset_name == "bbh"
|
||||
train_ratio = 0.2
|
||||
eval_ratio = 0
|
||||
|
||||
# train-validation-test split
|
||||
# It is important to sort the indices, as this ensures the is_multiple_choice
|
||||
# Boolean variables match the data points.
|
||||
assert train_ratio + eval_ratio <= 1
|
||||
test_ratio = 1 - train_ratio - eval_ratio
|
||||
print(
|
||||
f"train_ratio: {train_ratio}, eval_ratio: {eval_ratio}, "
|
||||
f"test_ratio: {test_ratio}"
|
||||
)
|
||||
np.random.seed(0)
|
||||
train_index = np.sort(
|
||||
np.array(
|
||||
np.random.choice(
|
||||
num_examples, size=int(train_ratio * num_examples), replace=False
|
||||
)
|
||||
)
|
||||
)
|
||||
eval_and_test_index = np.sort(
|
||||
np.array(list(set(np.arange(num_examples)) - set(train_index)))
|
||||
)
|
||||
eval_index = np.sort(
|
||||
np.array(
|
||||
np.random.choice(
|
||||
eval_and_test_index,
|
||||
size=int(eval_ratio * num_examples),
|
||||
replace=False,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# ========== set other optimization experiment hyperparameters ==============
|
||||
if scorer_llm_name == "text-bison":
|
||||
old_instruction_score_threshold = 0.0  # keep every old instruction: a threshold of 0.0 means no historical instructions are filtered out, even low-quality ones
|
||||
# old_instruction_score_threshold = 0.15 # for GSM8K
|
||||
elif scorer_llm_name == "local":
|
||||
old_instruction_score_threshold = 0.3
|
||||
else:
|
||||
assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}  # validate the model name
old_instruction_score_threshold = 0.3  # filter out low-quality old instructions
|
||||
|
||||
if scorer_llm_name == "text-bison":
|
||||
extract_final_answer_by_prompting_again = False  # whether to prompt again to extract the final answer (e.g., pulling the key content out of a verbose response)
include_qa = False  # whether to include QA pairs in the meta-prompt
evaluate_in_parallel = False  # whether to evaluate in parallel
|
||||
elif scorer_llm_name == "local":
|
||||
extract_final_answer_by_prompting_again = True
|
||||
include_qa = True
|
||||
evaluate_in_parallel = True
|
||||
else:
|
||||
assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}
|
||||
extract_final_answer_by_prompting_again = False
|
||||
include_qa = False
|
||||
evaluate_in_parallel = False
|
||||
|
||||
optimizer_llm_temperature = optimizer_llm_dict["temperature"]
|
||||
|
||||
num_few_shot_questions_for_instruction_refinement = 3  # number of few-shot QA exemplars referenced each time an instruction is refined
|
||||
|
||||
# To change the number of generated instructions in each step, one should
|
||||
# edit the value of the variable below, instead of editing the number of
|
||||
# decodes in model parameters, because those values are limited by model
|
||||
# serving configs.
|
||||
num_generated_instructions_in_each_step = 3  # number of candidate instructions generated in each search step
num_search_steps = 50  # total number of optimization iterations
|
||||
|
||||
initial_instructions = [
|
||||
"Let's solve the problem.",
|
||||
# "",
|
||||
# "The answer is",
|
||||
]
|
||||
few_shot_qa_pairs = True  # whether to use few-shot exemplars to guide instruction generation
|
||||
# one of {'accumulative_most_frequent', 'current_most_frequent', 'random',
|
||||
# 'constant'}
|
||||
few_shot_selection_criteria = "random" #对多样性要求高时用 random,稳定性要求高时用 most_frequent。
|
||||
# whether to evaluate generated instructions on the exemplars in meta-prompt
|
||||
evaluate_generated_ins_on_few_shot = False  # whether to evaluate newly generated instructions on the few-shot exemplars; set to True during development to debug instruction quality
|
||||
# whether to evaluate old instructions on the exemplars in the meta-prompt
|
||||
evaluate_old_ins_on_few_shot = False  # whether to evaluate old instructions on the few-shot exemplars; keep False in production for speed
|
||||
# every this number of steps, compute the accuracies of current-step
|
||||
# instructions on the validation set
|
||||
eval_interval = 3
|
||||
|
||||
max_num_instructions = (
|
||||
20  # maximum number of historical instructions kept in the meta-prompt
|
||||
)
|
||||
# discretize continuous scores into this many buckets (e.g., integers 0-100) so the model can interpret them more easily
|
||||
num_score_buckets = 100
|
||||
# whether to put old instructions and scores before the exemplars in the
# meta-prompt; controls the ordering of historical instructions and few-shot exemplars
|
||||
meta_prompt_instructions_before_exemplars = True
|
||||
|
||||
# ===================== run prompt optimization ======================
|
||||
|
||||
assert few_shot_selection_criteria in {
|
||||
"accumulative_most_frequent",
|
||||
"current_most_frequent",
|
||||
"random",
|
||||
"constant",
|
||||
}
|
||||
evolution_kwargs = {
|
||||
"num_search_steps": num_search_steps,
|
||||
"old_instruction_score_threshold": old_instruction_score_threshold,
|
||||
"scorer_llm_dict": scorer_llm_dict,
|
||||
"optimizer_llm_dict": optimizer_llm_dict,
|
||||
"extract_final_answer_by_prompting_again": (
|
||||
extract_final_answer_by_prompting_again
|
||||
),
|
||||
"include_qa": include_qa,
|
||||
"evaluate_in_parallel": evaluate_in_parallel,
|
||||
"tasks_all": tasks_all,
|
||||
"train_ratio": train_ratio,
|
||||
"eval_ratio": eval_ratio,
|
||||
"test_ratio": test_ratio,
|
||||
"train_index": train_index,
|
||||
"eval_index": eval_index,
|
||||
"dataset_name": dataset_name,
|
||||
"task_name": task_name,
|
||||
"num_examples": num_examples,
|
||||
"root_data_folder_path": root_data_folder_path,
|
||||
"optimizer_llm_temperature": optimizer_llm_temperature,
|
||||
# "optimizer_llm_temperature_schedule": (
|
||||
# optimizer_llm_temperature_schedule
|
||||
# ),
|
||||
# "optimizer_llm_temperature_end": optimizer_llm_temperature_end,
|
||||
"initial_instructions": initial_instructions,
|
||||
"multiple_choice_tasks": multiple_choice_tasks,
|
||||
"raw_data": raw_data,
|
||||
"call_scorer_server_func": call_scorer_server_func,
|
||||
"call_optimizer_server_func": call_optimizer_server_func,
|
||||
"instruction_pos": instruction_pos,
|
||||
"prediction_treat_as_number": prediction_treat_as_number,
|
||||
"prediction_treat_as_bool": prediction_treat_as_bool,
|
||||
"result_by_instruction_folder": result_by_instruction_folder,
|
||||
"few_shot_qa_pairs": few_shot_qa_pairs,
|
||||
"num_score_buckets": num_score_buckets,
|
||||
"max_num_instructions": max_num_instructions,
|
||||
"meta_prompt_type": meta_prompt_type,
|
||||
"meta_prompt_instructions_before_exemplars": (
|
||||
meta_prompt_instructions_before_exemplars
|
||||
),
|
||||
"few_shot_selection_criteria": few_shot_selection_criteria,
|
||||
"optimizer_llm_name": optimizer_llm_name,
|
||||
"num_generated_instructions_in_each_step": (
|
||||
num_generated_instructions_in_each_step
|
||||
),
|
||||
"evaluate_generated_ins_on_few_shot": evaluate_generated_ins_on_few_shot,
|
||||
"num_few_shot_questions_for_instruction_refinement": (
|
||||
num_few_shot_questions_for_instruction_refinement
|
||||
),
|
||||
"evaluate_old_ins_on_few_shot": evaluate_old_ins_on_few_shot,
|
||||
"eval_interval": eval_interval,
|
||||
"save_folder": save_folder,
|
||||
}
|
||||
|
||||
opt_utils.run_evolution(**evolution_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
@@ -1,2 +0,0 @@
|
||||
{"session_id": "06e84e4b03c54df188961f4e033f9e77", "round": 1, "choice": "哪些奶茶含有较少的糖分?", "answer": "选择低糖奶茶时,可以通过品牌选择、糖分选项调整以及配料搭配来控制摄入的糖分。以下是一些具体建议和推荐:\n\n---\n\n### **一、低糖奶茶品牌及选项**\n1. **喜茶(HEYTEA)** \n - **推荐系列**:轻乳茶(如「轻乳茉莉奶绿」)、部分限定款(如「芝士奶盖茶」可选低糖)。 \n - **糖分控制**:默认糖度为50%,可选「三分糖」(约10%糖度)或「无糖」(使用代糖)。 \n - **注意**:奶盖本身含糖,建议选择「无奶盖」选项。\n\n2. **奈雪的茶** \n - **推荐系列**:鲜果茶(如「霸气橙子」)、「鲜果轻乳茶」系列。 \n - **糖分控制**:默认糖度为50%,可选「少糖」(约30%)或「无糖」(使用代糖)。 \n - **注意**:鲜果茶的天然果糖含量较高,需结合糖分选项综合判断。\n\n3. **CoCo都可** \n - **推荐系列**:「鲜果茶」(如「百香果鲜果茶」)、「轻乳茶」。 \n - **糖分控制**:默认全糖(约100%),可选「三分糖」(约30%)或「无糖」(代糖)。 \n - **注意**:部分经典奶茶(如「珍珠奶茶」)糖分较高,建议选择鲜果类饮品。\n\n4. **一点点** \n - **推荐系列**:「鲜果茶」(如「百香果鲜果茶」)、「轻乳茶」。 \n - **糖分控制**:默认全糖(约100%),可选「半糖」(50%)或「无糖」(代糖)。 \n - **注意**:部分奶盖茶含糖量较高,建议选择鲜果或轻乳系列。\n\n---\n\n### **二、如何选择低糖奶茶?**\n1. **糖分选项优先级** \n - **无糖/代糖**:使用赤藓糖醇、甜菊糖等代糖,几乎不含热量(但需注意部分人对代糖敏感)。 \n - **三分糖/半糖**:糖分约为全糖的30%-50%,适合控制摄入。 \n - **避免全糖**:全糖奶茶的糖分通常在"}
|
||||
{"session_id": "af370656137a4b8f97d8ffac2c60b191", "round": 1, "choice": "瑞幸咖啡有哪些值得推荐的饮品?", "answer": "瑞幸咖啡近年来凭借高性价比和创新口味,推出了许多受欢迎的饮品。以下是一些值得推荐的饮品,涵盖经典款、特色新品和季节限定,适合不同口味和需求:\n\n---\n\n### **一、爆款必试系列**\n1. **生椰拿铁** \n - **推荐理由**:瑞幸的“现象级”产品,椰香浓郁与咖啡的苦香完美融合,口感顺滑,常年占据销量榜首。 \n - **小贴士**:可尝试“生椰拿铁+巧克力糖浆”组合,变成“生椰摩卡”,风味更丰富。\n\n2. **丝绒拿铁** \n - **推荐理由**:以红茶为基底,搭配丝滑牛奶,茶香与奶香平衡,适合喜欢茶咖融合的人群。 \n - **特色**:选用锡兰红茶,口感更醇厚,冷热皆宜。\n\n3. **厚乳拿铁** \n - **推荐理由**:使用厚乳(高乳脂含量的牛奶),奶香更浓郁,适合追求绵密口感的爱好者。\n\n---\n\n### **二、果味与创意系列**\n1. **冰椰拿铁**(夏季限定) \n - **推荐理由**:生椰拿铁的冰饮版本,加入冰块和椰香糖浆,清爽解暑,适合夏天。\n\n2. **蓝莓生椰拿铁** \n - **推荐理由**:在生椰拿铁基础上加入蓝莓糖浆,果香与椰香交织,甜而不腻。\n\n3. **蜜桃生椰拿铁** \n - **推荐理由**:蜜桃风味糖浆与生椰拿1:1搭配,清新果香与咖啡的苦香碰撞,适合喜欢果味的人。\n\n---\n\n### **三、季节限定款**\n1. **桂花拿铁**(秋季限定) \n - **推荐理由**:桂花糖浆与拿铁结合,香气扑鼻,甜度适中,是"}
|
||||
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
fastapi==0.109.0
|
||||
uvicorn==0.27.0
|
||||
requests==2.31.0
|
||||
numpy==1.26.3
|
||||
scikit-learn==1.4.0
|
||||
pydantic==2.5.3
|
||||
|
||||
184
test_opro_api.py
Normal file
184
test_opro_api.py
Normal file
@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""
Test script for TRUE OPRO API endpoints.

This script tests the complete OPRO workflow:
1. Create OPRO run
2. Generate initial candidates
3. Evaluate candidates
4. Generate optimized candidates
5. View results

Usage:
    python test_opro_api.py
"""

import requests
import json
import time

BASE_URL = "http://127.0.0.1:8010"

def print_section(title):
    """Print a section header."""
    print("\n" + "=" * 60)
    print(f" {title}")
    print("=" * 60)

def test_opro_workflow():
    """Test the complete OPRO workflow."""

    print_section("1. Create OPRO Run")

    # Create a new OPRO run
    create_req = {
        "task_description": "将用户输入的中文翻译成英文,要求准确自然",
        "test_cases": [
            {"input": "你好", "expected_output": "Hello"},
            {"input": "谢谢", "expected_output": "Thank you"},
            {"input": "早上好", "expected_output": "Good morning"},
            {"input": "晚安", "expected_output": "Good night"},
            {"input": "再见", "expected_output": "Goodbye"}
        ]
    }

    response = requests.post(f"{BASE_URL}/opro/create", json=create_req)
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to create OPRO run: {result}")
        return

    run_id = result["data"]["run_id"]
    print(f"✅ Created OPRO run: {run_id}")
    print(f" Task: {result['data']['task_description']}")
    print(f" Test cases: {result['data']['num_test_cases']}")

    # ========================================================================
    print_section("2. Generate Initial Candidates")

    iterate_req = {"run_id": run_id, "top_k": 5}
    response = requests.post(f"{BASE_URL}/opro/iterate", json=iterate_req)
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to generate candidates: {result}")
        return

    candidates = result["data"]["candidates"]
    print(f"✅ Generated {len(candidates)} initial candidates:")
    for i, candidate in enumerate(candidates, 1):
        print(f"\n [{i}] {candidate[:100]}...")

    # ========================================================================
    print_section("3. Evaluate Candidates")

    scores = []
    for i, candidate in enumerate(candidates, 1):
        print(f"\n Evaluating candidate {i}/{len(candidates)}...")

        eval_req = {
            "run_id": run_id,
            "instruction": candidate
        }

        response = requests.post(f"{BASE_URL}/opro/evaluate", json=eval_req)
        result = response.json()

        if result.get("success"):
            score = result["data"]["score"]
            scores.append(score)
            is_best = "🏆" if result["data"]["is_new_best"] else ""
            print(f" ✅ Score: {score:.4f} {is_best}")
        else:
            print(f" ❌ Evaluation failed: {result}")

        time.sleep(0.5)  # Small delay to avoid overwhelming the API

    print(f"\n Average score: {sum(scores)/len(scores):.4f}")
    print(f" Best score: {max(scores):.4f}")

    # ========================================================================
    print_section("4. Generate Optimized Candidates (Iteration 2)")

    print(" Generating candidates based on performance trajectory...")

    iterate_req = {"run_id": run_id, "top_k": 5}
    response = requests.post(f"{BASE_URL}/opro/iterate", json=iterate_req)
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to generate optimized candidates: {result}")
        return

    optimized_candidates = result["data"]["candidates"]
    print(f"✅ Generated {len(optimized_candidates)} optimized candidates:")
    for i, candidate in enumerate(optimized_candidates, 1):
        print(f"\n [{i}] {candidate[:100]}...")

    # ========================================================================
    print_section("5. View Run Details")

    response = requests.get(f"{BASE_URL}/opro/run/{run_id}")
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to get run details: {result}")
        return

    data = result["data"]
    print(f"✅ OPRO Run Details:")
    print(f" Run ID: {data['run_id']}")
    print(f" Task: {data['task_description']}")
    print(f" Iteration: {data['iteration']}")
    print(f" Status: {data['status']}")
    print(f" Best Score: {data['best_score']:.4f}")
    print(f"\n Best Instruction:")
    print(f" {data['best_instruction'][:200]}...")

    print(f"\n Top 5 Trajectory:")
    for i, item in enumerate(data['trajectory'][:5], 1):
        print(f" [{i}] Score: {item['score']:.4f}")
        print(f" {item['instruction'][:80]}...")

    # ========================================================================
    print_section("6. List All Runs")

    response = requests.get(f"{BASE_URL}/opro/runs")
    result = response.json()

    if result.get("success"):
        runs = result["data"]["runs"]
        print(f"✅ Total OPRO runs: {result['data']['total']}")
        for run in runs:
            print(f"\n Run: {run['run_id']}")
            print(f" Task: {run['task_description'][:50]}...")
            print(f" Iteration: {run['iteration']}, Best Score: {run['best_score']:.4f}")

    print_section("✅ OPRO Workflow Test Complete!")
    print(f"\nRun ID: {run_id}")
    print("You can view details at:")
    print(f" {BASE_URL}/opro/run/{run_id}")


if __name__ == "__main__":
    print("=" * 60)
    print(" TRUE OPRO API Test")
    print("=" * 60)
    print(f"\nBase URL: {BASE_URL}")
    print("\nMake sure the API server is running:")
    print(" uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010")
    print("\nStarting test in 3 seconds...")
    time.sleep(3)

    try:
        test_opro_workflow()
    except requests.exceptions.ConnectionError:
        print("\n❌ ERROR: Could not connect to API server")
        print("Please start the server first:")
        print(" uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010")
    except Exception as e:
        print(f"\n❌ ERROR: {e}")
        import traceback
        traceback.print_exc()
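For quick reference, the response shapes that test_opro_api.py relies on can be reconstructed from the fields it reads. The dictionary below is only that reconstruction: the sample values are placeholders, and any fields the API actually returns but the script never accesses are unknown and omitted.

# Response shapes inferred from test_opro_api.py; values are illustrative placeholders.
inferred_responses = {
    "POST /opro/create": {
        "success": True,
        "data": {"run_id": "<uuid>", "task_description": "<text>", "num_test_cases": 5},
    },
    "POST /opro/iterate": {
        "success": True,
        "data": {"candidates": ["<instruction 1>", "<instruction 2>"]},
    },
    "POST /opro/evaluate": {
        "success": True,
        "data": {"score": 0.87, "is_new_best": True},
    },
    "GET /opro/run/{run_id}": {
        "success": True,
        "data": {
            "run_id": "<uuid>",
            "task_description": "<text>",
            "iteration": 2,
            "status": "<status>",
            "best_score": 0.87,
            "best_instruction": "<text>",
            "trajectory": [{"score": 0.87, "instruction": "<text>"}],
        },
    },
    "GET /opro/runs": {
        "success": True,
        "data": {
            "total": 1,
            "runs": [
                {"run_id": "<uuid>", "task_description": "<text>", "iteration": 2, "best_score": 0.87}
            ],
        },
    },
}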
131
test_session_api.py
Normal file
131
test_session_api.py
Normal file
@@ -0,0 +1,131 @@
#!/usr/bin/env python3
"""
Test script for OPRO session-based API
"""

import requests
import json

BASE_URL = "http://127.0.0.1:8010"

def print_section(title):
    print(f"\n{'='*60}")
    print(f" {title}")
    print(f"{'='*60}\n")

def test_session_workflow():
    """Test the complete session-based workflow."""

    print_section("1. Create Session")

    # Create a new session
    response = requests.post(f"{BASE_URL}/opro/session/create")
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to create session: {result}")
        return

    session_id = result["data"]["session_id"]
    print(f"✅ Session created: {session_id}")
    print(f" Session name: {result['data']['session_name']}")

    print_section("2. Create First Run in Session")

    # Create first run
    create_req = {
        "task_description": "将中文翻译成英文",
        "test_cases": [
            {"input": "你好", "expected_output": "Hello"},
            {"input": "谢谢", "expected_output": "Thank you"}
        ],
        "session_id": session_id
    }

    response = requests.post(f"{BASE_URL}/opro/create", json=create_req)
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to create run: {result}")
        return

    run1_id = result["data"]["run_id"]
    print(f"✅ Run 1 created: {run1_id}")
    print(f" Task: {result['data']['task_description']}")

    print_section("3. Create Second Run in Same Session")

    # Create second run in same session
    create_req2 = {
        "task_description": "将英文翻译成中文",
        "test_cases": [
            {"input": "Hello", "expected_output": "你好"},
            {"input": "Thank you", "expected_output": "谢谢"}
        ],
        "session_id": session_id
    }

    response = requests.post(f"{BASE_URL}/opro/create", json=create_req2)
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to create run 2: {result}")
        return

    run2_id = result["data"]["run_id"]
    print(f"✅ Run 2 created: {run2_id}")
    print(f" Task: {result['data']['task_description']}")

    print_section("4. Get Session Details")

    response = requests.get(f"{BASE_URL}/opro/session/{session_id}")
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to get session: {result}")
        return

    print(f"✅ Session details:")
    print(f" Session ID: {result['data']['session_id']}")
    print(f" Session name: {result['data']['session_name']}")
    print(f" Number of runs: {result['data']['num_runs']}")
    print(f" Runs:")
    for run in result['data']['runs']:
        print(f" - {run['run_id'][:8]}... : {run['task_description']}")

    print_section("5. List All Sessions")

    response = requests.get(f"{BASE_URL}/opro/sessions")
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to list sessions: {result}")
        return

    print(f"✅ Total sessions: {len(result['data']['sessions'])}")
    for session in result['data']['sessions']:
        print(f" - {session['session_name']}: {session['num_runs']} runs")

    print_section("✅ All Tests Passed!")

if __name__ == "__main__":
    try:
        # Check if server is running
        response = requests.get(f"{BASE_URL}/health")
        if response.status_code != 200:
            print("❌ Server is not running. Please start it with:")
            print(" uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010")
            exit(1)

        test_session_workflow()

    except requests.exceptions.ConnectionError:
        print("❌ Cannot connect to server. Please start it with:")
        print(" uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010")
        exit(1)
    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        exit(1)
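Both test scripts repeat the same pattern: send a request, parse JSON, and bail out unless "success" is true. If the tests grow, that boilerplate could be factored into a small helper along the lines sketched here. This helper does not exist in the repository; the call() name and the 30-second timeout are arbitrary choices, and it assumes the same success/data envelope as above.

# Optional refactor sketch (not part of either test script).
import sys

import requests

BASE_URL = "http://127.0.0.1:8010"


def call(method: str, path: str, **kwargs) -> dict:
    """Send a request and return the 'data' payload, exiting on failure."""
    response = requests.request(method, f"{BASE_URL}{path}", timeout=30, **kwargs)
    result = response.json()
    if not result.get("success"):
        print(f"❌ {method} {path} failed: {result}")
        sys.exit(1)
    return result["data"]


# Example usage mirroring step 1 of test_session_api.py:
# session = call("POST", "/opro/session/create")
# print(session["session_id"])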