SB3: Action Masked PPO for Connect Four

Warning

Currently, this tutorial does not work with versions of gymnasium newer than 0.29.1. We are working on a fix, but it may take some time.

This tutorial shows how to train agents in the Connect Four environment (AEC) using maskable Proximal Policy Optimization (PPO).

It creates a custom wrapper that converts the environment into a Gymnasium-like environment compatible with SB3 action masking.

After training and evaluation, this script launches a demo game using human rendering. The trained model is saved to and loaded from disk (see SB3's documentation for more information).

Note

This environment has a discrete (1-dimensional) observation space with an illegal action mask, so we use a masked MLP feature extractor.
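For reference, here is a minimal sketch (assuming the standard connect_four_v3 AEC API) of what a raw observation dictionary looks like before the wrapper defined below strips the mask out:

from pettingzoo.classic import connect_four_v3

env = connect_four_v3.env()
env.reset(seed=42)
obs, reward, termination, truncation, info = env.last()
print(obs.keys())          # dict_keys(['observation', 'action_mask'])
print(obs["action_mask"])  # binary vector marking which columns are legal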

Warning

The SB3ActionMaskWrapper wrapper assumes that the action space and observation space are the same for every agent; this assumption may not hold for custom environments.
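If you adapt this wrapper to a custom environment, a quick sanity check along these lines (an illustrative sketch, not part of the tutorial script) can confirm that assumption before training:

# Illustrative check: confirm that every agent shares identical observation
# and action spaces before relying on SB3ActionMaskWrapper.
from pettingzoo.classic import connect_four_v3

env = connect_four_v3.env()
env.reset(seed=0)
reference = env.possible_agents[0]
assert all(
    env.observation_space(agent) == env.observation_space(reference)
    and env.action_space(agent) == env.action_space(reference)
    for agent in env.possible_agents
), "SB3ActionMaskWrapper requires identical spaces across agents"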

Environment Setup

To follow this tutorial, you will need to install the dependencies shown below. We recommend using a freshly created virtual environment to avoid dependency conflicts.

pettingzoo[classic]>=1.24.0
stable-baselines3>=2.0.0
sb3-contrib>=2.0.0
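Because of the warning above, you may also want to pin gymnasium explicitly until compatibility with newer releases is fixed, for example:

gymnasium<=0.29.1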

Code

The following code should run without issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the Discord server.

Training and Evaluation

"""Uses Stable-Baselines3 to train agents in the Connect Four environment using invalid action masking.

For information about invalid action masking in PettingZoo, see https://pettingzoo.farama.cn/api/aec/#action-masking
For more information about invalid action masking in SB3, see https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html

Author: Elliot (https://github.com/elliottower)
"""

import glob
import os
import time

import gymnasium as gym
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker

import pettingzoo.utils
from pettingzoo.classic import connect_four_v3


# To pass into other gymnasium wrappers, we need to ensure that pettingzoo's wrapper
# can also be a gymnasium Env. Thus, we subclass gym.Env as well.
class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper, gym.Env):
    """Wrapper to allow PettingZoo environments to be used with SB3 illegal action masking."""

    def reset(self, seed=None, options=None):
        """Gymnasium-like reset function which assigns obs/action spaces to be the same for each agent.

        This is required as SB3 is designed for single-agent RL and doesn't expect obs/action spaces to be functions
        """
        super().reset(seed, options)

        # Strip the action mask out from the observation space
        self.observation_space = super().observation_space(self.possible_agents[0])[
            "observation"
        ]
        self.action_space = super().action_space(self.possible_agents[0])

        # Return initial observation, info (PettingZoo AEC envs do not by default)
        return self.observe(self.agent_selection), {}

    def step(self, action):
        """Gymnasium-like step function, returning observation, reward, termination, truncation, info.

        The observation is for the next agent (used to determine the next action), while the remaining
        items are for the agent that just acted (used to understand what just happened).
        """
        current_agent = self.agent_selection

        super().step(action)

        next_agent = self.agent_selection
        return (
            self.observe(next_agent),
            self._cumulative_rewards[current_agent],
            self.terminations[current_agent],
            self.truncations[current_agent],
            self.infos[current_agent],
        )

    def observe(self, agent):
        """Return only raw observation, removing action mask."""
        return super().observe(agent)["observation"]

    def action_mask(self):
        """Separate function used in order to access the action mask."""
        return super().observe(self.agent_selection)["action_mask"]


def mask_fn(env):
    # Do whatever you'd like in this function to return the action mask
    # for the current env. In this example, we assume the env has a
    # helpful method we can rely on.
    return env.action_mask()


def train_action_mask(env_fn, steps=10_000, seed=0, **env_kwargs):
    """Train a single model to play as each agent in a zero-sum game environment using invalid action masking."""
    env = env_fn.env(**env_kwargs)

    print(f"Starting training on {str(env.metadata['name'])}.")

    # Custom wrapper to convert PettingZoo envs to work with SB3 action masking
    env = SB3ActionMaskWrapper(env)

    env.reset(seed=seed)  # Must call reset() in order to re-define the spaces

    env = ActionMasker(env, mask_fn)  # Wrap to enable masking (SB3 function)
    # MaskablePPO behaves the same as SB3's PPO unless the env is wrapped
    # with ActionMasker. If the wrapper is detected, the masks are automatically
    # retrieved and used when learning. Note that MaskablePPO does not accept
    # a new action_mask_fn kwarg, as it did in an earlier draft.
    model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1)
    model.set_random_seed(seed)
    model.learn(total_timesteps=steps)

    model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}")

    print("Model has been saved.")

    print(f"Finished training on {str(env.unwrapped.metadata['name'])}.\n")

    env.close()


def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs):
    # Evaluate a trained agent vs a random agent
    env = env_fn.env(render_mode=render_mode, **env_kwargs)

    print(
        f"Starting evaluation vs a random agent. Trained agent will play as {env.possible_agents[1]}."
    )

    try:
        latest_policy = max(
            glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime
        )
    except ValueError:
        print("Policy not found.")
        exit(0)

    model = MaskablePPO.load(latest_policy)

    scores = {agent: 0 for agent in env.possible_agents}
    total_rewards = {agent: 0 for agent in env.possible_agents}
    round_rewards = []

    for i in range(num_games):
        env.reset(seed=i)
        env.action_space(env.possible_agents[0]).seed(i)

        for agent in env.agent_iter():
            obs, reward, termination, truncation, info = env.last()

            # Separate observation and action mask
            observation, action_mask = obs.values()

            if termination or truncation:
                # If there is a winner, keep track, otherwise don't change the scores (tie)
                if (
                    env.rewards[env.possible_agents[0]]
                    != env.rewards[env.possible_agents[1]]
                ):
                    winner = max(env.rewards, key=env.rewards.get)
                    scores[winner] += env.rewards[
                        winner
                    ]  # only tracks the largest reward (winner of game)
                # Also track negative and positive rewards (penalizes illegal moves)
                for a in env.possible_agents:
                    total_rewards[a] += env.rewards[a]
                # List of rewards by round, for reference
                round_rewards.append(env.rewards)
                break
            else:
                if agent == env.possible_agents[0]:
                    act = env.action_space(agent).sample(action_mask)
                else:
                    # Note: PettingZoo expects integer actions # TODO: change chess to cast actions to type int?
                    act = int(
                        model.predict(
                            observation, action_masks=action_mask, deterministic=True
                        )[0]
                    )
            env.step(act)
    env.close()

    # Avoid dividing by zero
    if sum(scores.values()) == 0:
        winrate = 0
    else:
        winrate = scores[env.possible_agents[1]] / sum(scores.values())
    print("Rewards by round: ", round_rewards)
    print("Total rewards (incl. negative rewards): ", total_rewards)
    print("Winrate: ", winrate)
    print("Final scores: ", scores)
    return round_rewards, total_rewards, winrate, scores


if __name__ == "__main__":
    env_fn = connect_four_v3

    env_kwargs = {}

    # Evaluation/training hyperparameter notes:
    # 10k steps: Winrate:  0.76, loss order of 1e-03
    # 20k steps: Winrate:  0.86, loss order of 1e-04
    # 40k steps: Winrate:  0.86, loss order of 7e-06

    # Train a model against itself (takes ~20 seconds on a laptop CPU)
    train_action_mask(env_fn, steps=20_480, seed=0, **env_kwargs)

    # Evaluate 100 games against a random agent (winrate should be ~80%)
    eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs)

    # Watch two games vs a random agent
    eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)

Testing other PettingZoo Classic environments

The following script uses pytest to test all other PettingZoo environments that support action masking.

This code gets reasonable results on simpler environments such as Connect Four, while more difficult environments such as Chess or Hanabi will likely require much more training time and hyperparameter tuning.

"""Tests that action masking code works properly with all PettingZoo classic environments."""

import pytest

from pettingzoo.classic import (
    chess_v6,
    gin_rummy_v4,
    go_v5,
    hanabi_v5,
    leduc_holdem_v4,
    texas_holdem_no_limit_v6,
    texas_holdem_v4,
    tictactoe_v3,
)

pytest.importorskip("stable_baselines3")
pytest.importorskip("sb3_contrib")

# Note: Connect Four is tested in sb3_connect_four_action_mask.py
# Note: Rock-Paper-Scissors has no action masking and does not seem to learn well playing against itself

# These environments do better than random even after the minimum number of timesteps
EASY_ENVS = [
    gin_rummy_v4,
    texas_holdem_no_limit_v6,  # texas holdem human rendered game ends instantly, but with random actions it works fine
    tictactoe_v3,
    leduc_holdem_v4,
]

# More difficult environments which will likely take more training time
MEDIUM_ENVS = [
    hanabi_v5,  # even with 10x as many steps, total score seems to always be tied between the two agents
    texas_holdem_v4,  # this performs poorly with updates to SB3 wrapper
    chess_v6,  # difficult to train because games take so long, performance varies heavily
]

# Most difficult environments to train agents for (and longest games)
# TODO: test board_size to see if smaller go board is more easily solvable
HARD_ENVS = [
    go_v5,  # difficult to train because games take so long, may be another issue causing poor performance
]


@pytest.mark.parametrize("env_fn", EASY_ENVS)
def test_action_mask_easy(env_fn):
    from tutorials.SB3.connect_four.sb3_connect_four_action_mask import (
        eval_action_mask,
        train_action_mask,
    )

    env_kwargs = {}

    steps = 8192 * 4

    # Train a model against itself (takes ~2 minutes on GPU)
    train_action_mask(env_fn, steps=steps, seed=0, **env_kwargs)

    # Evaluate 100 games against a random agent
    round_rewards, total_rewards, winrate, scores = eval_action_mask(
        env_fn, num_games=100, render_mode=None, **env_kwargs
    )

    assert winrate > 0.5 or (
        total_rewards[env_fn.env().possible_agents[1]]
        > total_rewards[env_fn.env().possible_agents[0]]
    ), "Trained policy should outperform random actions"

    # Watch two games (disabled by default)
    # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)


# @pytest.mark.skip(
#     reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI"
# )
@pytest.mark.parametrize("env_fn", MEDIUM_ENVS)
def test_action_mask_medium(env_fn):
    from tutorials.SB3.connect_four.sb3_connect_four_action_mask import (
        eval_action_mask,
        train_action_mask,
    )

    env_kwargs = {}

    # Train a model against itself
    train_action_mask(env_fn, steps=8192, seed=0, **env_kwargs)

    # Evaluate 100 games against a random agent
    round_rewards, total_rewards, winrate, scores = eval_action_mask(
        env_fn, num_games=100, render_mode=None, **env_kwargs
    )

    assert (
        winrate < 0.75
    ), "Policy should not perform better than 75% winrate"  # 30-40% for leduc, 0% for hanabi

    # Watch two games (disabled by default)
    # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)


# @pytest.mark.skip(
#     reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI"
# )
@pytest.mark.parametrize("env_fn", HARD_ENVS)
def test_action_mask_hard(env_fn):
    from tutorials.SB3.connect_four.sb3_connect_four_action_mask import (
        eval_action_mask,
        train_action_mask,
    )

    env_kwargs = {}

    # Train a model against itself
    train_action_mask(env_fn, steps=8192, seed=0, **env_kwargs)

    # Evaluate 100 games against a random agent
    round_rewards, total_rewards, winrate, scores = eval_action_mask(
        env_fn, num_games=100, render_mode=None, **env_kwargs
    )

    assert winrate < 0.5, "Policy should not perform better than 50% winrate"  # 0% for go

    # Watch two games (disabled by default)
    # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)