SB3:在 Waterworld 上使用 PPO
本教程展示了如何使用 Proximal Policy Optimization (PPO) 在 Waterworld 环境 (Parallel) 中训练智能体。
我们使用 SuperSuit 创建向量化环境,并利用多进程并行来加速训练(参见 SB3 的向量环境文档)。
训练和评估完成后,此脚本会以 human 渲染模式(render_mode="human")运行几局演示游戏。训练好的模型会保存到磁盘,并在评估时从磁盘加载(参见 SB3 的模型保存文档)。
注意
此环境的观测空间是一维的连续向量(Box),因此我们使用 MLP 特征提取器。
环境设置
要遵循本教程,你需要安装如下所示的依赖项。建议使用新创建的虚拟环境以避免依赖冲突。
pettingzoo[sisl]>=1.24.0
stable-baselines3>=2.0.0
supersuit>=3.9.0
pymunk
代码
以下代码应能直接运行。注释旨在帮助你理解如何在 SB3 中使用 PettingZoo。如有任何问题,欢迎在 Discord 服务器中提问。
训练和评估
"""Uses Stable-Baselines3 to train agents to play the Waterworld environment using SuperSuit vector envs.
For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
Author: Elliot (https://github.com/elliottower)
"""
from __future__ import annotations
import glob
import os
import time
import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from pettingzoo.sisl import waterworld_v4
def train_butterfly_supersuit(
    env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs
):
    """Train a single shared PPO policy that plays as every agent of a Parallel env.

    The PettingZoo Parallel environment is converted into an SB3-compatible
    vectorized environment with SuperSuit and concatenated 8 times
    (``num_cpus=2``). One ``MlpPolicy`` therefore learns from the experience of
    all agents at once. The trained model is saved in the current working
    directory as ``<env name>_<YYYYmmdd-HHMMSS>`` (SB3 appends ``.zip``).

    Args:
        env_fn: PettingZoo environment module exposing ``parallel_env()``
            (e.g. ``waterworld_v4``).
        steps: Total number of environment timesteps to train for.
        seed: Seed for the initial ``reset`` of the raw Parallel environment.
        **env_kwargs: Extra keyword arguments forwarded to ``parallel_env()``.
    """
    env = env_fn.parallel_env(**env_kwargs)
    env.reset(seed=seed)

    # Capture the name from the raw PettingZoo env before wrapping. This way we
    # do not depend on the SuperSuit/SB3 vec-env wrappers forwarding
    # `metadata` through `.unwrapped`.
    env_name = str(env.metadata["name"])
    print(f"Starting training on {env_name}.")

    env = ss.pettingzoo_env_to_vec_env_v1(env)
    env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3")
    try:
        # Waterworld's observations are continuous 1-D vectors (a Box of shape
        # (242,) in the default configuration), so an MLP policy is used
        # rather than a CNN.
        model = PPO(
            MlpPolicy,
            env,
            verbose=3,
            learning_rate=1e-3,
            batch_size=256,
        )
        model.learn(total_timesteps=steps)
        model.save(f"{env_name}_{time.strftime('%Y%m%d-%H%M%S')}")
        print("Model has been saved.")
        print(f"Finished training on {env_name}.")
    finally:
        # Always release the vectorized env (and any worker processes it
        # started), even if training or saving fails.
        env.close()
def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs):
# Evaluate a trained agent vs a random agent
env = env_fn.env(render_mode=render_mode, **env_kwargs)
print(
f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})"
)
try:
latest_policy = max(
glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime
)
except ValueError:
print("Policy not found.")
exit(0)
model = PPO.load(latest_policy)
rewards = {agent: 0 for agent in env.possible_agents}
# Note: We train using the Parallel API but evaluate using the AEC API
# SB3 models are designed for single-agent settings, we get around this by using he same model for every agent
for i in range(num_games):
env.reset(seed=i)
for agent in env.agent_iter():
obs, reward, termination, truncation, info = env.last()
for a in env.agents:
rewards[a] += env.rewards[a]
if termination or truncation:
break
else:
act = model.predict(obs, deterministic=True)[0]
env.step(act)
env.close()
avg_reward = sum(rewards.values()) / len(rewards.values())
print("Rewards: ", rewards)
print(f"Avg reward: {avg_reward}")
return avg_reward
if __name__ == "__main__":
    env_fn = waterworld_v4
    env_kwargs = {}

    # Train a model (takes ~3 minutes on GPU).
    train_butterfly_supersuit(env_fn, steps=196_608, seed=0, **env_kwargs)

    # First evaluate 10 headless games (the average reward should be positive
    # but can vary significantly), then watch 2 games rendered for a human.
    eval_runs = ((10, None), (2, "human"))
    for games, mode in eval_runs:
        eval(env_fn, num_games=games, render_mode=mode, **env_kwargs)