For how to train, save, and load a model, see the earlier posts in this series:
ray.rllib 入门实践-5: 训练算法-CSDN博客
ray.rllib 入门实践-6: 保存模型-CSDN博客
ray.rllib 入门实践-7: 加载训练好的模型-CSDN博客
This post simply trains, saves, and loads a model following the approaches recommended there, and then introduces two ways to evaluate the trained model.
Environment:
torch==2.5.1
ray==2.10.0
ray[rllib]==2.10.0
ray[tune]==2.10.0
ray[serve]==2.10.0
numpy==1.23.0
python==3.9.18
1. Training and saving the model
import os
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.tune.logger import pretty_print

## Configure the algorithm
storage_path = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs"
os.makedirs(storage_path, exist_ok=True)

config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0, num_cpus_per_worker=1, num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1", env_config={})
config = config.evaluation(evaluation_num_workers=1)  ## Required if you want to call algo.evaluate() later; without it the call fails
config.output = storage_path  ## Directory for the output (experience) files

## Build the algorithm
algo = config.build()

## Train the algorithm
for i in range(3):
    result = algo.train()
    print(f"training_iteration_{i}")

## Save the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
os.makedirs(checkpoint_dir,exist_ok=True)
algo.save_checkpoint(checkpoint_dir)  ## Save a checkpoint to the given directory
print(f"saved checkpoint to {checkpoint_dir}")
2. Model evaluation
Method 1: multi-episode statistical evaluation
## Method 1: algo.evaluate()
## Precondition: the evaluation options were set in the algorithm's config during training; otherwise this method fails.
## It runs multiple episodes, aggregates the results, and returns the statistics.
## Load the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"Loaded checkpoint from {checkpoint_dir}")
## Evaluate the model
evaluation_result = algo.evaluate()  ## Requires the evaluation options to have been set when the algorithm was configured; otherwise this call fails
print(pretty_print(evaluation_result))
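The returned dict is nested, and pretty_print dumps all of it. If you only need a single number such as the mean episode reward, you can pull it out directly. The sketch below reuses evaluation_result from above; the key layout ("evaluation" -> "episode_reward_mean") is what Ray 2.10 typically reports, but it can differ across RLlib versions, so treat these key names as an assumption.
## Sketch: extract a single metric from the evaluation result (key names assumed, may vary by version)
eval_metrics = evaluation_result.get("evaluation", evaluation_result)
mean_reward = eval_metrics.get("episode_reward_mean")
print(f"episode_reward_mean = {mean_reward}")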
Method 2: single-episode evaluation
import gymnasium as gym
## Create the environment
env_name = "CartPole-v1"
env = gym.make(env_name)
## Load the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"Loaded checkpoint from {checkpoint_dir}")
## Run inference
step = 0
episode_reward = 0
terminated = truncated = False
obs,info = env.reset()
while not terminated and not truncated:
    action = algo.compute_single_action(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
    step += 1
    print(f"step = {step}, reward = {reward}, "
          f"action = {action}, obs = {obs}, "
          f"episode_reward = {episode_reward}")
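The single-episode loop above is easy to extend. The sketch below (my own addition, not from the referenced posts) reuses env and algo from above, rolls out several episodes, and averages the returns; passing explore=False to compute_single_action makes the policy act greedily, which is a common choice for evaluation.
## Sketch: evaluate over several episodes and average the returns
num_episodes = 5
returns = []
for ep in range(num_episodes):
    obs, info = env.reset()
    terminated = truncated = False
    ep_return = 0.0
    while not terminated and not truncated:
        action = algo.compute_single_action(obs, explore=False)  # greedy action for evaluation
        obs, reward, terminated, truncated, info = env.step(action)
        ep_return += reward
    returns.append(ep_return)
    print(f"episode {ep}: return = {ep_return}")
print(f"mean return over {num_episodes} episodes = {sum(returns) / num_episodes}")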
3. Complete code
import os
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.tune.logger import pretty_print
import gymnasium as gym

## Configure the algorithm
storage_path = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs"
os.makedirs(storage_path, exist_ok=True)

config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0, num_cpus_per_worker=1, num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1", env_config={})
config = config.evaluation(evaluation_num_workers=1)  ## Required if you want to call algo.evaluate() later; without it the call fails
config.output = storage_path  ## Directory for the output (experience) files

## Build the algorithm
algo = config.build()

## Train the algorithm
for i in range(3):
    result = algo.train()
    print(f"training_iteration_{i}")

## Save the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
os.makedirs(checkpoint_dir,exist_ok=True)
algo.save_checkpoint(checkpoint_dir)  ## Save a checkpoint to the given directory
print(f"saved checkpoint to {checkpoint_dir}")################# evaluate #############################
# ## Method 1: algo.evaluate(). Runs multiple episodes, aggregates the results, and returns the statistics.
# ## Load the model
# checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
# algo = PPO.from_checkpoint(checkpoint_dir)
# print(f"Loaded checkpoint from {checkpoint_dir}")
# ## Evaluate the model
# evaluation_result = algo.evaluate()  ## Requires the evaluation options to have been set when the algorithm was configured; otherwise this call fails
# print(pretty_print(evaluation_result))

## Method 2: algo.compute_single_action(obs)
## Create the environment
env_name = "CartPole-v1"
env = gym.make(env_name)
## Load the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"Loaded checkpoint from {checkpoint_dir}")
## Run inference
step = 0
episode_reward = 0
terminated = truncated = False
obs,info = env.reset()
while not terminated and not truncated:
    action = algo.compute_single_action(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
    step += 1
    print(f"step = {step}, reward = {reward}, "
          f"action = {action}, obs = {obs}, "
          f"episode_reward = {episode_reward}")
