欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 汽车 > 维修 > Stable-Baseline3 x SwanLab:可视化强化学习训练

Stable-Baseline3 x SwanLab:可视化强化学习训练

2025/10/12 7:02:17 来源:https://blog.csdn.net/SoulmateY/article/details/139689361  浏览:    关键词:Stable-Baseline3 x SwanLab:可视化强化学习训练

Stable Baselines3 (SB3) 是一个强化学习的开源库,基于 PyTorch 框架构建。它是 Stable Baselines 项目的继任者,旨在提供一组可靠且经过良好测试的RL算法实现,便于研究和应用。StableBaseline3主要被应用于机器人控制、游戏AI、自动驾驶、金融交易等领域。

在这里插入图片描述

你可以使用sb3快速进行模型训练,同时使用SwanLab进行实验跟踪与可视化。

1.引入SwanLabCallback

from swanlab.integration.sb3 import SwanLabCallback

SwanLabCallback是适配于 Stable Baselines3 的日志记录类。

SwanLabCallback可以定义的参数有:

  • project、experiment_name、description 等与 swanlab.init 效果一致的参数, 用于SwanLab项目的初始化。
  • 你也可以在外部通过swanlab.init创建项目,集成会将实验记录到你在外部创建的项目中。

2.传入model.learn

from swanlab.integration.sb3 import SwanLabCallback...model.learn(...callback=SwanLabCallback(),
)

model.learncallback参数传入SwanLabCallback实例,即可开始跟踪。

3.完整案例代码

下面是一个PPO模型的简单训练案例,使用SwanLab做训练可视化和监控:

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
import swanlab
from swanlab.integration.sb3 import SwanLabCallbackconfig = {"policy_type": "MlpPolicy","total_timesteps": 25000,"env_name": "CartPole-v1",
}def make_env():env = gym.make(config["env_name"], render_mode="rgb_array")env = Monitor(env)return envenv = DummyVecEnv([make_env])
model = PPO(config["policy_type"],env,verbose=1,
)model.learn(total_timesteps=config["total_timesteps"],callback=SwanLabCallback(project="PPO",experiment_name="MlpPolicy",verbose=2,),
)swanlab.finish()

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词