RLlib: DQN for Simple Poker
This tutorial shows how to train a Deep Q-Network (DQN) agent on the Leduc Hold'em environment (AEC).
After training, run the provided code to watch your trained agent play against itself. See the documentation for more information.
Environment Setup
To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. An example install command is given after the list.
PettingZoo[classic,butterfly]>=1.24.0
Pillow>=9.4.0
ray[rllib]==2.7.0
SuperSuit>=3.9.0
torch>=1.13.1
tensorflow-probability>=0.19.0
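If you save the dependency list above as requirements.txt (a filename assumed here for convenience), one way to install everything is:

pip install -r requirements.txt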
Code
The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLlib. If you have any questions, please feel free to ask in the Discord server.
Training the RL Agent
"""Uses Ray's RLlib to train agents to play Leduc Holdem.
Author: Rohan (https://github.com/Rohan138)
"""
import os
import ray
from gymnasium.spaces import Box, Discrete
from ray import tune
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.env import PettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MAX
from ray.tune.registry import register_env
from pettingzoo.classic import leduc_holdem_v4
torch, nn = try_import_torch()
class TorchMaskedActions(DQNTorchModel):
"""PyTorch version of above ParametricActionsModel."""
def __init__(
self,
obs_space: Box,
action_space: Discrete,
num_outputs,
model_config,
name,
**kw,
):
DQNTorchModel.__init__(
self, obs_space, action_space, num_outputs, model_config, name, **kw
)
obs_len = obs_space.shape[0] - action_space.n
orig_obs_space = Box(
shape=(obs_len,), low=obs_space.low[:obs_len], high=obs_space.high[:obs_len]
)
self.action_embed_model = TorchFC(
orig_obs_space,
action_space,
action_space.n,
model_config,
name + "_action_embed",
)
def forward(self, input_dict, state, seq_lens):
# Extract the available actions tensor from the observation.
action_mask = input_dict["obs"]["action_mask"]
# Compute the predicted action embedding
action_logits, _ = self.action_embed_model(
{"obs": input_dict["obs"]["observation"]}
)
# turns probit action mask into logit action mask
inf_mask = torch.clamp(torch.log(action_mask), -1e10, FLOAT_MAX)
return action_logits + inf_mask, state
def value_function(self):
return self.action_embed_model.value_function()
if __name__ == "__main__":
    ray.init()

    alg_name = "DQN"
    ModelCatalog.register_custom_model("pa_model", TorchMaskedActions)

    # Function that outputs the environment you wish to register.
    def env_creator():
        env = leduc_holdem_v4.env()
        return env

    env_name = "leduc_holdem_v4"
    register_env(env_name, lambda config: PettingZooEnv(env_creator()))

    # Instantiate the environment once to read its observation/action spaces.
    test_env = PettingZooEnv(env_creator())
    obs_space = test_env.observation_space
    act_space = test_env.action_space

    config = (
        DQNConfig()
        .environment(env=env_name)
        .rollouts(num_rollout_workers=1, rollout_fragment_length=30)
        .training(
            train_batch_size=200,
            hiddens=[],
            dueling=False,
            model={"custom_model": "pa_model"},
        )
        .multi_agent(
            policies={
                "player_0": (None, obs_space, act_space, {}),
                "player_1": (None, obs_space, act_space, {}),
            },
            # Map each PettingZoo agent id to the policy of the same name.
            policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
        )
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
        .debugging(
            log_level="DEBUG"
        )  # TODO: change to ERROR to match pistonball example
        .framework(framework="torch")
        .exploration(
            exploration_config={
                # The Exploration class to use.
                "type": "EpsilonGreedy",
                # Config for the Exploration class' constructor:
                "initial_epsilon": 0.1,
                "final_epsilon": 0.0,
                "epsilon_timesteps": 100000,  # Timesteps over which to anneal epsilon.
            }
        )
    )

    tune.run(
        alg_name,
        name="DQN",
        stop={"timesteps_total": 10000000 if not os.environ.get("CI") else 50000},
        checkpoint_freq=10,
        config=config.to_dict(),
    )
Watching the Trained RL Agents Play
"""Uses Ray's RLlib to view trained agents playing Leduc Holdem.
Author: Rohan (https://github.com/Rohan138)
"""
import argparse
import os
import numpy as np
import ray
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from rllib_leduc_holdem import TorchMaskedActions
from pettingzoo.classic import leduc_holdem_v4
os.environ["SDL_VIDEODRIVER"] = "dummy"
parser = argparse.ArgumentParser(
description="Render pretrained policy loaded from checkpoint"
)
parser.add_argument(
"--checkpoint-path",
help="Path to the checkpoint. This path will likely be something like this: `~/ray_results/pistonball_v6/PPO/PPO_pistonball_v6_660ce_00000_0_2021-06-11_12-30-57/checkpoint_000050/checkpoint-50`",
)
args = parser.parse_args()
if args.checkpoint_path is None:
print("The following arguments are required: --checkpoint-path")
exit(0)
checkpoint_path = os.path.expanduser(args.checkpoint_path)
alg_name = "DQN"
ModelCatalog.register_custom_model("pa_model", TorchMaskedActions)
# function that outputs the environment you wish to register.
def env_creator():
    env = leduc_holdem_v4.env()
    return env
env = env_creator()
env_name = "leduc_holdem_v4"
register_env(env_name, lambda config: PettingZooEnv(env_creator()))
ray.init()
DQNAgent = Algorithm.from_checkpoint(checkpoint_path)
reward_sums = {a: 0 for a in env.possible_agents}
i = 0
env.reset()
for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    obs = observation["observation"]
    reward_sums[agent] += reward
    if termination or truncation:
        action = None
    else:
        print(DQNAgent.get_policy(agent))
        policy = DQNAgent.get_policy(agent)
        # Batch the single observation so it matches the policy's expected input shape.
        batch_obs = {
            "obs": {
                "observation": np.expand_dims(observation["observation"], 0),
                "action_mask": np.expand_dims(observation["action_mask"], 0),
            }
        }
        batched_action, state_out, info = policy.compute_actions_from_input_dict(
            batch_obs
        )
        single_action = batched_action[0]
        action = single_action

    env.step(action)
    i += 1
    env.render()
print("rewards:")
print(reward_sums)
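Assuming the watching script is saved as render_rllib_leduc_holdem.py (the filename is only an example), point it at one of the checkpoints produced during training, e.g.:

python render_rllib_leduc_holdem.py --checkpoint-path ~/ray_results/DQN/DQN_leduc_holdem_v4_<experiment_id>/checkpoint_000010

The script plays one game of Leduc Hold'em with both players controlled by the trained policies and prints each player's cumulative reward at the end.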