Official GitHub: https://github.com/hiyouga/EasyR1
Reference: https://opendeep.wiki/hiyouga/EasyR1/quickstart
Code and environment setup
Create a new virtual environment and install the pinned dependencies:
python -m venv easyr1
source easyr1/bin/activate
python -m pip install transformers==4.51.0
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1
python -m pip install wheel
python -m pip install flash-attn==2.7.4.post1
python -m pip install vllm==0.8.3
Install EasyR1:
git clone https://github.com/hiyouga/EasyR1.git
cd EasyR1
pip install -e .
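After installing, it can help to confirm that the pinned versions are the ones actually active in the virtual environment. The following is a minimal sketch, not part of EasyR1: `check_versions` is a hypothetical helper, and the expected versions are simply copied from the install commands above.

```python
from importlib import metadata

# Versions pinned in the install commands above.
EXPECTED = {
    "transformers": "4.51.0",
    "torch": "2.5.1",
    "torchvision": "0.20.1",
    "torchaudio": "2.5.1",
    "flash-attn": "2.7.4.post1",
    "vllm": "0.8.3",
}


def check_versions(expected: dict) -> dict:
    """Return {package: problem} for each package that is missing or mismatched."""
    problems = {}
    for name, wanted in expected.items():
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems[name] = "not installed"
            continue
        # A local build tag such as "2.5.1+cu124" still counts as a match.
        if found.split("+")[0] != wanted:
            problems[name] = f"found {found}, expected {wanted}"
    return problems


if __name__ == "__main__":
    for name, problem in check_versions(EXPECTED).items():
        print(f"{name}: {problem}")
```

An empty result means every pinned package was found at the expected version.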
Dataset
Text dataset: https://huggingface.co/datasets/hiyouga/math12k
Reference for building the dataset
Reference code:
import json
import os

from datasets import Dataset, DatasetDict


def generate_data(data_path: str):
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            yield {
                "problem": data["problem"],
                "answer": data["answer"],
            }


def main():
    trainset = Dataset.from_generator(generate_data, gen_kwargs={"data_path": os.path.join("prm800k", "math_splits", "train.jsonl")})
    testset = Dataset.from_generator(generate_data, gen_kwargs={"data_path": os.path.join("prm800k", "math_splits", "test.jsonl")})
    dataset = DatasetDict({"train": trainset, "test": testset})
    dataset.push_to_hub("hiyouga/math12k")


if __name__ == "__main__":
    main()
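To prepare your own data in the same shape, each JSONL line needs at least a "problem" and an "answer" field. The sketch below uses only the standard library to write a tiny sample file and read it back the way the reference code does. The sample records, `write_jsonl`, `read_problem_answer`, and `demo` are hypothetical and only illustrate the format.

```python
import json
import os
import tempfile


def write_jsonl(path, records):
    """Write one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_problem_answer(path):
    """Yield only the "problem" and "answer" fields of each line."""
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # skip blank lines
            data = json.loads(line)
            yield {"problem": data["problem"], "answer": data["answer"]}


def demo():
    # Made-up records; extra fields such as "subject" are dropped on read.
    records = [
        {"problem": "What is 1 + 1?", "answer": "2", "subject": "arithmetic"},
        {"problem": "What is 3 * 4?", "answer": "12", "subject": "arithmetic"},
    ]
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "train.jsonl")
        write_jsonl(path, records)
        return list(read_problem_answer(path))


if __name__ == "__main__":
    for row in demo():
        print(row)
```

A generator of this shape can be passed to `Dataset.from_generator` with `gen_kwargs`, as in the reference code.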
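Main parameters to modify
Parameter reference 1
Parameter reference 2
Config path: examples/config.yaml
data
- train_files: training set path
- val_files: test set path
- max_prompt_length: input length limit
- max_response_length: output length limit
- rollout_batch_size:
- mini_rollout_batch_size:
- format_prompt: the Jinja template file, chosen to match the LLM
worker
- actor.model: model path
- rollout.n: number of responses sampled per prompt within a group; the default is 5, I set it to 8
- reward.reward_function: path to the reward function
trainer
- experiment_name: experiment name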
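Instead of editing examples/config.yaml directly, the values can be collected in one place and turned into dotted key=value overrides. This is a minimal sketch: `flatten_overrides` is a hypothetical helper, the values are placeholders, and passing such overrides after `config=examples/config.yaml` on the command line is an assumption based on the example scripts in the repository, so check it against the version you have cloned.

```python
def flatten_overrides(config: dict, prefix: str = "") -> list:
    """Turn a nested dict into a list of dotted "key=value" strings."""
    args = []
    for key, value in config.items():
        dotted = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            args.extend(flatten_overrides(value, dotted))
        else:
            args.append(f"{dotted}={value}")
    return args


# Placeholder values; the keys follow the parameter list above.
overrides = {
    "data": {
        "max_prompt_length": 2048,
        "max_response_length": 2048,
        "rollout_batch_size": 16,
    },
    "worker": {"rollout": {"n": 8}},
    "trainer": {"experiment_name": "my_experiment"},
}

if __name__ == "__main__":
    # Assumed launch form: python3 -m verl.trainer.main config=... key=value ...
    command = ["python3", "-m", "verl.trainer.main", "config=examples/config.yaml"]
    command += flatten_overrides(overrides)
    print(" ".join(command))
```

Keeping the overrides in a dict makes it easy to log exactly which settings differ from the default config for each experiment.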
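Errors encountered
- The code hangs at "Started a local Ray instance. View the dashboard at 127.0.0.1:8265"
- Failed to register worker to raylet: IOError
These two problems occur together.
Reference solution 1
Reference solution 2
What I did:
Set every batch size to 1, except global_batch_size, which is set to the number of GPUs. rollout_batch_size must also be a multiple of global_batch_size; I set rollout_batch_size to 8 or 16.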
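The batch-size rule above can be checked before launching a run. The following is a minimal sketch of that check: `check_batch_sizes` is a hypothetical helper, not an EasyR1 function, and it only encodes the two constraints described in these notes.

```python
def check_batch_sizes(n_gpus: int, global_batch_size: int, rollout_batch_size: int) -> list:
    """Return a list of violated constraints (empty if the settings are consistent)."""
    errors = []
    if global_batch_size <= 0:
        return ["global_batch_size must be positive"]
    # Rule 1 from the notes: global_batch_size equals the number of GPUs.
    if global_batch_size != n_gpus:
        errors.append(
            f"global_batch_size ({global_batch_size}) should equal the GPU count ({n_gpus})"
        )
    # Rule 2 from the notes: rollout_batch_size is a multiple of global_batch_size.
    if rollout_batch_size % global_batch_size != 0:
        errors.append(
            f"rollout_batch_size ({rollout_batch_size}) must be a multiple of "
            f"global_batch_size ({global_batch_size})"
        )
    return errors


if __name__ == "__main__":
    for problem in check_batch_sizes(n_gpus=8, global_batch_size=8, rollout_batch_size=12):
        print(problem)
```

With 8 GPUs, for example, a rollout_batch_size of 8 or 16 passes, while 12 is rejected.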
Code path: verl/trainer/config.py
Adjust this parameter:
worker:reward:num_cpus: 1
In addition, force num_cpus in the Ray initialization in:
/mnt/gemininjceph3/geminicephfs/mmsearch-luban-universal/group_2/user_skylarshao/EasyR1/verl/trainer/main.py
ray.init(runtime_env=runtime_env)
Change it to:
ray.init(runtime_env=runtime_env, num_cpus=1)