
ray.rllib Hands-on Tutorial 8: Model Inference and Evaluation

Source: https://blog.csdn.net/Strive_For_Future/article/details/145353407

For how to train, save, and load the model, see the earlier posts in this series:

        ray.rllib Hands-on Tutorial 5: Training the Algorithm (CSDN blog)

        ray.rllib Hands-on Tutorial 6: Saving the Model (CSDN blog)

        ray.rllib Hands-on Tutorial 7: Loading a Trained Model (CSDN blog)

This post first trains, saves, and loads a model following the approach recommended in those posts, and then introduces two ways to evaluate the model.

Environment:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

1. Training and Saving the Model

import os
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.tune.logger import pretty_print

## Configure the algorithm
storage_path = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs"
os.makedirs(storage_path, exist_ok=True)

config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0, num_cpus_per_worker=1, num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1", env_config={})
config = config.evaluation(evaluation_num_workers=1)  ## Required here if you want to call algo.evaluate() later; otherwise that call does not work.
config.output = storage_path  ## Where intermediate output files are stored

## Build the algorithm
algo = config.build()

## Train the algorithm
for i in range(3):
    result = algo.train()
    print(f"episode_{i}")

## Save the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
algo.save_checkpoint(checkpoint_dir)  ## Save the checkpoint to the given directory
print(f"saved checkpoint to {checkpoint_dir}")

2. Model Evaluation

Method 1: Multi-episode statistical evaluation

## Method 1: algo.evaluate()
## Prerequisite: evaluation options must have been set on the algorithm config during training; otherwise this method fails.
## It runs multiple episodes, aggregates the results, and returns the statistics.
## Load the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"Loaded checkpoint from {checkpoint_dir}")
## Evaluate the model
evaluation_result = algo.evaluate()  ## Fails unless evaluation options were configured before training
print(pretty_print(evaluation_result))
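The dictionary returned by algo.evaluate() is quite large. As a small sketch for pulling out the headline numbers (the key names "evaluation", "episode_reward_mean", and "episode_len_mean" are assumptions based on ray 2.10's old API stack and may differ in other versions):

## Sketch: read the headline metrics from the evaluation result.
## Key names are assumed from ray 2.10 and may vary across versions.
eval_metrics = evaluation_result.get("evaluation", evaluation_result)
print("mean episode reward:", eval_metrics.get("episode_reward_mean"))
print("mean episode length:", eval_metrics.get("episode_len_mean"))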

Method 2: Single-episode evaluation

import gymnasium as gym

## Create the environment
env_name = "CartPole-v1"
env = gym.make(env_name)

## Load the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"Loaded checkpoint from {checkpoint_dir}")

## Run inference
step = 0
episode_reward = 0
terminated = truncated = False
obs, info = env.reset()
while not terminated and not truncated:
    action = algo.compute_single_action(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
    step += 1
    print(f"step = {step}, reward = {reward}, "
          f"action = {action}, obs = {obs}, "
          f"episode_reward = {episode_reward}")
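A single episode can be a noisy estimate of policy quality. If you want Method 2 to report something closer to the statistics that algo.evaluate() gives, one simple extension (a sketch, not part of the original post) is to run several episodes and average the returns:

## Sketch: run several episodes with compute_single_action() and average the returns.
num_episodes = 10  ## illustrative episode count, adjust as needed
returns = []
for _ in range(num_episodes):
    obs, info = env.reset()
    terminated = truncated = False
    episode_return = 0.0
    while not terminated and not truncated:
        action = algo.compute_single_action(obs)
        obs, reward, terminated, truncated, info = env.step(action)
        episode_return += reward
    returns.append(episode_return)
print(f"mean return over {num_episodes} episodes: {sum(returns) / len(returns)}")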

3. Complete Code

import os
import gymnasium as gym
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.tune.logger import pretty_print

## Configure the algorithm
storage_path = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs"
os.makedirs(storage_path, exist_ok=True)

config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0, num_cpus_per_worker=1, num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1", env_config={})
config = config.evaluation(evaluation_num_workers=1)  ## Required here if you want to call algo.evaluate() later; otherwise that call does not work.
config.output = storage_path  ## Where intermediate output files are stored

## Build the algorithm
algo = config.build()

## Train the algorithm
for i in range(3):
    result = algo.train()
    print(f"episode_{i}")

## Save the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
algo.save_checkpoint(checkpoint_dir)  ## Save the checkpoint to the given directory
print(f"saved checkpoint to {checkpoint_dir}")

#################  evaluate  #############################
# ## Method 1: algo.evaluate(). Runs multiple episodes, aggregates the results, and returns the statistics.
# ## Load the model
# checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
# algo = PPO.from_checkpoint(checkpoint_dir)
# print(f"Loaded checkpoint from {checkpoint_dir}")
# ## Evaluate the model
# evaluation_result = algo.evaluate()  ## Fails unless evaluation options were configured before training
# print(pretty_print(evaluation_result))

## Method 2: algo.compute_single_action(obs)
## Create the environment
env_name = "CartPole-v1"
env = gym.make(env_name)

## Load the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"Loaded checkpoint from {checkpoint_dir}")

## Run inference
step = 0
episode_reward = 0
terminated = truncated = False
obs, info = env.reset()
while not terminated and not truncated:
    action = algo.compute_single_action(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
    step += 1
    print(f"step = {step}, reward = {reward}, "
          f"action = {action}, obs = {obs}, "
          f"episode_reward = {episode_reward}")
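If you want to watch the trained policy instead of only reading the printed rewards, gymnasium can render CartPole to a window. This is an optional addition not covered in the original post, and it assumes the pygame-based renderer for CartPole-v1 is installed; only the environment creation changes:

## Sketch: create the environment with an on-screen renderer (requires pygame for CartPole-v1).
env = gym.make("CartPole-v1", render_mode="human")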
