SB3:适用于四子棋的 Action Masked PPO
警告
当前,本教程不适用于 gymnasium>0.29.1 的版本。我们正在努力修复,但这可能需要一些时间。
本教程展示了如何使用可遮罩的近端策略优化 (PPO) 在四子棋环境 (AEC) 中训练智能体。
它创建了一个自定义包装器,将 PettingZoo 环境转换为类似 Gymnasium 的环境,使其与 SB3 的动作遮罩功能兼容。
训练和评估完成后,此脚本将以 human 渲染模式 (render_mode="human") 启动一局演示游戏。训练好的模型会保存到磁盘,并从磁盘加载(更多信息请参阅 SB3 的文档)。
注意
此环境具有带有非法动作遮罩的离散(1维)观察空间,因此我们使用带遮罩的 MLP 特征提取器。
警告
SB3ActionMaskWrapper 包装器假设每个智能体的动作空间和观察空间是相同的,此假设对于自定义环境可能不成立。
环境设置
要按照本教程进行操作,您需要安装下面所示的依赖项。建议使用新创建的虚拟环境以避免依赖冲突。
pettingzoo[classic]>=1.24.0
stable-baselines3>=2.0.0
sb3-contrib>=2.0.0
代码
以下代码应该可以顺利运行。注释旨在帮助您理解如何在 SB3 中使用 PettingZoo。如果您有任何问题,请随时在 Discord 服务器中提问。
训练和评估
"""Uses Stable-Baselines3 to train agents in the Connect Four environment using invalid action masking.
For information about invalid action masking in PettingZoo, see https://pettingzoo.farama.cn/api/aec/#action-masking
For more information about invalid action masking in SB3, see https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html
Author: Elliot (https://github.com/elliottower)
"""
import glob
import os
import time
import gymnasium as gym
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker
import pettingzoo.utils
from pettingzoo.classic import connect_four_v3
# To pass this env into other Gymnasium wrappers, PettingZoo's wrapper must
# also be a Gymnasium Env. Thus, we subclass gym.Env as well.
class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper, gym.Env):
    """Expose a PettingZoo AEC env through the single-agent Gymnasium API used by SB3.

    Observations handed to SB3 are only the ``"observation"`` entry of the
    PettingZoo observation dict; the ``"action_mask"`` entry is made available
    separately through :meth:`action_mask`. The spaces are taken from
    ``possible_agents[0]``, so all agents are assumed to share them.
    """

    def reset(self, seed=None, options=None):
        """Reset the env and pin fixed (non-callable) observation/action spaces.

        SB3 is single-agent and expects ``observation_space``/``action_space``
        to be attributes rather than per-agent functions, so they are
        overwritten here with the spaces of the first possible agent.

        Returns:
            ``(observation, info)`` for the agent whose turn it is; ``info``
            is always an empty dict.
        """
        super().reset(seed, options)
        reference_agent = self.possible_agents[0]
        # Keep only the "observation" sub-space so the mask never reaches the policy.
        full_observation_space = super().observation_space(reference_agent)
        self.observation_space = full_observation_space["observation"]
        self.action_space = super().action_space(reference_agent)
        # PettingZoo AEC reset() returns nothing by default; build the
        # Gymnasium-style return value by hand.
        first_observation = self.observe(self.agent_selection)
        return first_observation, {}

    def step(self, action):
        """Apply ``action`` and return a Gymnasium 5-tuple.

        The observation belongs to the agent that moves next (it drives the
        next action), while reward, termination, truncation and info describe
        the agent that has just moved.
        """
        acting_agent = self.agent_selection
        super().step(action)
        observation = self.observe(self.agent_selection)
        reward = self._cumulative_rewards[acting_agent]
        terminated = self.terminations[acting_agent]
        truncated = self.truncations[acting_agent]
        info = self.infos[acting_agent]
        return observation, reward, terminated, truncated, info

    def observe(self, agent):
        """Return ``agent``'s raw observation with the action mask stripped off."""
        observation_dict = super().observe(agent)
        return observation_dict["observation"]

    def action_mask(self):
        """Return the action mask of the agent whose turn it currently is."""
        observation_dict = super().observe(self.agent_selection)
        return observation_dict["action_mask"]
def mask_fn(env):
    """Mask callback for ``ActionMasker``: return ``env``'s current action mask.

    Relies on ``env`` providing an ``action_mask()`` method; adapt this
    function if your env exposes its mask differently.
    """
    read_mask = env.action_mask
    return read_mask()
def train_action_mask(env_fn, steps=10_000, seed=0, **env_kwargs):
    """Train a single model to play as each agent in a zero-sum game using invalid action masking.

    One MaskablePPO model plays every agent's turns (self-play). The model is
    saved as ``<env name>_<YYYYmmdd-HHMMSS>.zip`` in the current directory.

    Args:
        env_fn: PettingZoo environment module exposing ``env(**kwargs)``.
        steps: Total number of environment timesteps to train for.
        seed: Seed for the initial env reset and for MaskablePPO's RNGs.
        **env_kwargs: Extra keyword arguments forwarded to ``env_fn.env``.
    """
    env = env_fn.env(**env_kwargs)
    env_name = env.metadata["name"]
    print(f"Starting training on {env_name}.")

    # Custom wrapper to convert PettingZoo envs to work with SB3 action masking
    env = SB3ActionMaskWrapper(env)
    env.reset(seed=seed)  # Must call reset() in order to re-define the spaces
    env = ActionMasker(env, mask_fn)  # Wrap to enable masking (SB3 function)

    try:
        # MaskablePPO behaves the same as SB3's PPO unless the env is wrapped
        # with ActionMasker. If the wrapper is detected, the masks are
        # automatically retrieved and used when learning.
        # Pass the seed to the constructor: SB3 seeds its RNGs there before the
        # policy network is built, whereas calling set_random_seed() afterwards
        # leaves the initial network weights unseeded.
        model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, seed=seed)
        model.learn(total_timesteps=steps)
        model.save(f"{env_name}_{time.strftime('%Y%m%d-%H%M%S')}")
        print("Model has been saved.")
        print(f"Finished training on {env_name}.\n")
    finally:
        # Release the env even if training fails part-way.
        env.close()
def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs):
    """Evaluate the most recently saved policy against a uniformly random agent.

    The random agent plays as ``possible_agents[0]`` and the trained
    MaskablePPO model as ``possible_agents[1]``. The model is loaded from the
    newest ``<env name>*.zip`` file in the current working directory.

    Args:
        env_fn: PettingZoo environment module exposing ``env(**kwargs)``.
        num_games: Number of games to play; game ``i`` is seeded with ``i``.
        render_mode: Forwarded to ``env_fn.env`` (e.g. ``"human"``).
        **env_kwargs: Extra keyword arguments forwarded to ``env_fn.env``.

    Returns:
        ``(round_rewards, total_rewards, winrate, scores)``:
        ``round_rewards`` holds a snapshot of the final rewards of every game,
        ``total_rewards`` sums every agent's final rewards (including negative
        ones), ``scores`` accumulates each game's winning reward per agent,
        and ``winrate`` is the trained agent's share of ``scores`` (0 when no
        game had a winner).

    If no saved policy is found, prints a message, closes the env and exits
    the process with status 1.
    """
    import sys

    env = env_fn.env(render_mode=render_mode, **env_kwargs)
    random_agent = env.possible_agents[0]
    trained_agent = env.possible_agents[1]
    print(
        f"Starting evaluation vs a random agent. Trained agent will play as {trained_agent}."
    )

    policy_files = glob.glob(f"{env.metadata['name']}*.zip")
    if not policy_files:
        print("Policy not found.")
        env.close()
        sys.exit(1)
    latest_policy = max(policy_files, key=os.path.getctime)
    model = MaskablePPO.load(latest_policy)

    scores = {agent: 0 for agent in env.possible_agents}
    total_rewards = {agent: 0 for agent in env.possible_agents}
    round_rewards = []

    try:
        for i in range(num_games):
            env.reset(seed=i)
            # Seed the random agent's sampler so its moves are reproducible.
            env.action_space(random_agent).seed(i)

            for agent in env.agent_iter():
                obs, _, termination, truncation, _ = env.last()
                # Look entries up by key instead of relying on dict ordering.
                observation = obs["observation"]
                action_mask = obs["action_mask"]

                if termination or truncation:
                    # Snapshot so later resets/steps cannot alter stored results.
                    final_rewards = dict(env.rewards)
                    # If there is a winner, keep track; ties leave scores unchanged.
                    if final_rewards[random_agent] != final_rewards[trained_agent]:
                        winner = max(final_rewards, key=final_rewards.get)
                        # Only tracks the largest reward (winner of the game).
                        scores[winner] += final_rewards[winner]
                    # Also track negative and positive rewards (penalizes illegal moves).
                    for a in env.possible_agents:
                        total_rewards[a] += final_rewards[a]
                    # List of rewards by round, for reference.
                    round_rewards.append(final_rewards)
                    break

                if agent == random_agent:
                    act = env.action_space(agent).sample(action_mask)
                else:
                    # Note: PettingZoo expects integer actions # TODO: change chess to cast actions to type int?
                    act = int(
                        model.predict(
                            observation, action_masks=action_mask, deterministic=True
                        )[0]
                    )
                env.step(act)
    finally:
        env.close()

    total_score = sum(scores.values())
    # Avoid dividing by zero when every game was a tie.
    winrate = 0 if total_score == 0 else scores[trained_agent] / total_score
    print("Rewards by round: ", round_rewards)
    print("Total rewards (incl. negative rewards): ", total_rewards)
    print("Winrate: ", winrate)
    print("Final scores: ", scores)
    return round_rewards, total_rewards, winrate, scores
if __name__ == "__main__":
    # Hyperparameter notes from earlier evaluation/training runs:
    #   10k steps -> winrate 0.76, loss on the order of 1e-03
    #   20k steps -> winrate 0.86, loss on the order of 1e-04
    #   40k steps -> winrate 0.86, loss on the order of 7e-06
    env_fn = connect_four_v3
    env_kwargs = {}

    # Self-play training (roughly 20 seconds on a laptop CPU).
    train_action_mask(env_fn, steps=20_480, seed=0, **env_kwargs)

    # First 100 headless games against a random agent (winrate should be
    # around 80%), then two human-rendered games to watch.
    for games, mode in ((100, None), (2, "human")):
        eval_action_mask(env_fn, num_games=games, render_mode=mode, **env_kwargs)
测试其他 PettingZoo 经典环境
以下脚本使用 pytest 来测试所有其他支持动作遮罩的 PettingZoo 环境。
这段代码在四子棋等简单环境中取得了不错的效果,而国际象棋或花火等更困难的环境则可能需要更多的训练时间和超参数调整。
"""Tests that action masking code works properly with all PettingZoo classic environments."""
import pytest
from pettingzoo.classic import (
chess_v6,
gin_rummy_v4,
go_v5,
hanabi_v5,
leduc_holdem_v4,
texas_holdem_no_limit_v6,
texas_holdem_v4,
tictactoe_v3,
)
# Skip this whole module when the SB3 libraries are not installed.
for _required_module in ("stable_baselines3", "sb3_contrib"):
    pytest.importorskip(_required_module)

# Connect Four is covered by sb3_connect_four_action_mask.py itself.
# Rock-Paper-Scissors is omitted: it has no action masking and does not seem
# to learn well when playing against itself.

# Environments that beat a random agent even after the minimum number of timesteps.
EASY_ENVS = [
    gin_rummy_v4,
    # Human-rendered Texas Hold'em games end instantly, but random actions work fine.
    texas_holdem_no_limit_v6,
    tictactoe_v3,
    leduc_holdem_v4,
]

# Harder environments that will likely need more training time.
MEDIUM_ENVS = [
    # Even with 10x as many steps, the total score seems always tied between the agents.
    hanabi_v5,
    # Performs poorly since the updates to the SB3 wrapper.
    texas_holdem_v4,
    # Games take so long that training is difficult; performance varies heavily.
    chess_v6,
]

# Hardest environments to train agents for (and the longest games).
# TODO: test board_size to see if a smaller go board is more easily solvable.
HARD_ENVS = [
    # Games take so long that training is difficult; another issue may also
    # be causing poor performance.
    go_v5,
]
@pytest.mark.parametrize("env_fn", EASY_ENVS)
def test_action_mask_easy(env_fn):
    """Self-play training on an easy env should outperform a random agent.

    Passes if the winrate exceeds 0.5, or if the trained agent
    (``possible_agents[1]``) collects more total reward than the random one.
    """
    from tutorials.SB3.connect_four.sb3_connect_four_action_mask import (
        eval_action_mask,
        train_action_mask,
    )

    env_kwargs = {}
    steps = 8192 * 4

    # Train a model against itself (takes ~2 minutes on GPU)
    train_action_mask(env_fn, steps=steps, seed=0, **env_kwargs)

    # Evaluate 100 games against a random agent
    round_rewards, total_rewards, winrate, scores = eval_action_mask(
        env_fn, num_games=100, render_mode=None, **env_kwargs
    )

    # Build one env only to read the agent names, and release it right away.
    probe_env = env_fn.env(**env_kwargs)
    random_agent = probe_env.possible_agents[0]
    trained_agent = probe_env.possible_agents[1]
    probe_env.close()

    assert winrate > 0.5 or (
        total_rewards[trained_agent] > total_rewards[random_agent]
    ), "Trained policy should outperform random actions"

    # Watch two games (disabled by default)
    # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)
# @pytest.mark.skip(
#     reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI"
# )
@pytest.mark.parametrize("env_fn", MEDIUM_ENVS)
def test_action_mask_medium(env_fn):
    """Assert that a short 8192-step self-play run on a medium env stays below a 0.75 winrate."""
    from tutorials.SB3.connect_four.sb3_connect_four_action_mask import (
        eval_action_mask,
        train_action_mask,
    )

    env_kwargs = {}

    # Train a model against itself
    train_action_mask(env_fn, steps=8192, seed=0, **env_kwargs)

    # Evaluate 100 games against a random agent
    round_rewards, total_rewards, winrate, scores = eval_action_mask(
        env_fn, num_games=100, render_mode=None, **env_kwargs
    )
    assert (
        winrate < 0.75
    ), "Policy should not perform better than 75% winrate"  # 30-40% for leduc, 0% for hanabi

    # Watch two games (disabled by default)
    # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)
# @pytest.mark.skip(
#     reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI"
# )
@pytest.mark.parametrize("env_fn", HARD_ENVS)
def test_action_mask_hard(env_fn):
    """Assert that a short 8192-step self-play run on a hard env stays below a 0.5 winrate."""
    from tutorials.SB3.connect_four.sb3_connect_four_action_mask import (
        eval_action_mask,
        train_action_mask,
    )

    env_kwargs = {}

    # Train a model against itself
    train_action_mask(env_fn, steps=8192, seed=0, **env_kwargs)

    # Evaluate 100 games against a random agent
    round_rewards, total_rewards, winrate, scores = eval_action_mask(
        env_fn, num_games=100, render_mode=None, **env_kwargs
    )
    # Upper bound matching the assertion message; go has been seen at 0%,
    # which a lower-bound check (winrate > 0) would wrongly reject.
    assert (
        winrate < 0.5
    ), "Policy should not perform better than 50% winrate"  # 0% for go

    # Watch two games (disabled by default)
    # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)