Going further into deep MARL with Halite and HandyRL
Jogima-cyber opened this issue · 5 comments
Hello there!
I've done a bit of work and research on the Kaggle Halite IV environment.
To my knowledge, nobody has ever tackled it with pure self-play deep RL, which makes it an interesting task. It's easy to see why: the number of agents varies during a game, the agents are heterogeneous (type SHIP and type SHIPYARD), ships and shipyards of the same team need to cooperate, and they compete against the agents of the other team. These characteristics are close to real-world problems that deep RL could help tackle.
Characteristics like these do seem to have been handled successfully by DeepMind and OpenAI (with AlphaStar and OpenAI Five). But their models output only one action per step for the whole team, not one action per agent. They succeed because they can issue a huge number of actions in a short time, which roughly amounts to taking one action per agent. They used action embeddings and a single neural network (so, in the nomenclature of the literature, a centralized neural network with centralized execution; see https://arxiv.org/pdf/1803.11485.pdf for an explanation).
For the Halite IV environment, handling actions that way is not possible: a game lasts at most 400 steps and all actions are taken simultaneously. A single-action model would take only 400 actions in total, whereas at each step you could take as many actions as you have agents.
So what can be done? Fortunately, this paper https://arxiv.org/pdf/2005.13625.pdf gives a mathematical justification for one way of handling all this: share a single neural network across all your agents; if the agents are heterogeneous, make their type visible to the network; and use action masking to handle the different action spaces of the heterogeneous agents within that single network.
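As a minimal sketch of that idea (the function and variable names here are mine, not from the paper): one set of weights serves every agent, the agent type is part of the observation, and a per-agent mask removes the actions that type cannot take.

```python
import torch
import torch.nn.functional as F

def masked_policy(shared_net, obs, action_mask):
    """One shared network for all agents, with per-agent action masking.

    obs:         (batch, obs_dim) observations, including an agent-type feature
    action_mask: (batch, n_actions), 1 for legal actions, 0 for illegal ones
    """
    logits = shared_net(obs)                                   # (batch, n_actions)
    logits = logits + torch.log(action_mask.clamp(min=1e-10))  # illegal actions get very negative logits
    return F.softmax(logits, dim=-1)                           # near-zero probability on illegal actions
```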
Which brings me to this post. I think HandyRL could be used, with very little tweaking, to build a state-of-the-art agent for Kaggle Halite IV, and I'd like to do it. Why HandyRL? Because it is the most reliable distributed policy-gradient library I know of so far. What would the tweaks be? Mainly supporting a varying number of agents within a single game (agents can create other agents, and agents can die during a game).
I've already written the preprocessing and the net:
import numpy as np
import copy

from kaggle_environments import make
from ray.rllib.env.multi_agent_env import MultiAgentEnv

SHIP_ACTIONS = ["CONVERT", "NORTH", "SOUTH", "EAST", "WEST", None, "SPAWN", None]
SHIPYARD_ACTIONS = ["SPAWN", None]


# For Fully Independent Learning
class HaliteEnv(MultiAgentEnv):

    def __init__(self, env_config):
        self.env = make("halite", debug=True)
        self.previous_alive_agent_ids = set()
        self.all_agent_ids = [set(), set()]
        self.previous_ships_halite = {}
    def observation(self, obs, agent_rllib_id):
        agent_type, team, agent_id = agent_rllib_id.split("_")
        team = int(team)
        obs = copy.deepcopy(obs)
        action_mask = np.zeros(8)
        final_obs_proc = None
        if agent_type == "SHIP":
            action_mask[:6] = 1
            final_obs_proc = np.zeros((11, 21, 21))
        elif agent_type == "SHIPYARD":
            final_obs_proc = np.zeros((11, 21, 21))
            action_mask[6:] = 1
        # Halite channel
        halites = obs[0].observation.halite
        for key, halite in enumerate(halites):
            x = key % 21
            y = key // 21
            final_obs_proc[0, x, y] = halite / 500  # normalization!
        players = obs[0].observation.players
        own_team = players[team]
        opponent_team = players[1 - team]
        x_ref = 10
        y_ref = 10
        if agent_type == "SHIP":
            own_ships = own_team[2]
            agent = own_ships.pop(agent_id, None)
            final_obs_proc[1] = 1
            if agent is not None:
                # Own ship channel
                x_ref = agent[0] % 21
                y_ref = agent[0] // 21
                # Own ship halite channel
                final_obs_proc[2] = agent[1] / 500 if agent[1] >= 0 else 0  # how should this be normalized?
            final_obs_proc[3] = obs[team].reward / 50000 if obs[team].reward >= 0 else 0  # how should this be normalized?
            final_obs_proc[4] = obs[1 - team].reward / 50000 if obs[1 - team].reward >= 0 else 0  # how should this be normalized?
            # Own other ships channel
            # Own other ships halite channel
            for own_other_ship_key in own_ships:
                x = own_ships[own_other_ship_key][0] % 21
                y = own_ships[own_other_ship_key][0] // 21
                final_obs_proc[5, x, y] = 1
                final_obs_proc[6, x, y] = own_ships[own_other_ship_key][1] / 500  # how should this be normalized?
            # Own shipyards channel
            own_shipyards = own_team[1]
            for own_shipyard_key in own_shipyards:
                x = own_shipyards[own_shipyard_key] % 21
                y = own_shipyards[own_shipyard_key] // 21
                final_obs_proc[7, x, y] = 1
            # Opponent ships channel
            # Opponent ships halite channel
            opponent_ships = opponent_team[2]
            for opponent_ship_key in opponent_ships:
                x = opponent_ships[opponent_ship_key][0] % 21
                y = opponent_ships[opponent_ship_key][0] // 21
                final_obs_proc[8, x, y] = 1
                final_obs_proc[9, x, y] = opponent_ships[opponent_ship_key][1] / 500  # how should this be normalized?
            # Opponent shipyards channel
            opponent_shipyards = opponent_team[1]
            for opponent_shipyard_key in opponent_shipyards:
                x = opponent_shipyards[opponent_shipyard_key] % 21
                y = opponent_shipyards[opponent_shipyard_key] // 21
                final_obs_proc[10, x, y] = 1
        elif agent_type == "SHIPYARD":
            own_shipyards = own_team[1]
            agent = own_shipyards.pop(agent_id, None)
            final_obs_proc[1] = 0.5
            if agent is not None:
                # Own shipyard channel
                x_ref = agent % 21
                y_ref = agent // 21
            final_obs_proc[3] = obs[team].reward / 50000 if obs[team].reward >= 0 else 0  # how should this be normalized?
            final_obs_proc[4] = obs[1 - team].reward / 50000 if obs[1 - team].reward >= 0 else 0  # how should this be normalized?
            # Own ships channel
            # Own ships halite channel
            own_ships = own_team[2]
            for own_other_ship_key in own_ships:
                x = own_ships[own_other_ship_key][0] % 21
                y = own_ships[own_other_ship_key][0] // 21
                final_obs_proc[5, x, y] = 1
                final_obs_proc[6, x, y] = own_ships[own_other_ship_key][1] / 500  # how should this be normalized?
            # Own other shipyards channel
            for own_shipyard_key in own_shipyards:
                x = own_shipyards[own_shipyard_key] % 21
                y = own_shipyards[own_shipyard_key] // 21
                final_obs_proc[7, x, y] = 1
            # Opponent ships channel
            # Opponent ships halite channel
            opponent_ships = opponent_team[2]
            for opponent_ship_key in opponent_ships:
                x = opponent_ships[opponent_ship_key][0] % 21
                y = opponent_ships[opponent_ship_key][0] // 21
                final_obs_proc[8, x, y] = 1
                final_obs_proc[9, x, y] = opponent_ships[opponent_ship_key][1] / 500  # how should this be normalized?
            # Opponent shipyards channel
            opponent_shipyards = opponent_team[1]
            for opponent_shipyard_key in opponent_shipyards:
                x = opponent_shipyards[opponent_shipyard_key] % 21
                y = opponent_shipyards[opponent_shipyard_key] // 21
                final_obs_proc[10, x, y] = 1
        # Recenter the map on the agent: roll the x axis (axis 1) by x_ref and the y axis (axis 2) by y_ref
        final_obs_proc = np.roll(final_obs_proc, 10 - x_ref, 1)
        final_obs_proc = np.roll(final_obs_proc, 10 - y_ref, 2)
        return {"obs": final_obs_proc, "action_mask": action_mask}
    def reset(self):
        obs = self.env.reset(2)
        players = obs[0].observation.players
        return_obs = {}
        ship1_name = list(players[0][2].keys())[0]
        ship2_name = list(players[1][2].keys())[0]
        agent1_name = "SHIP_0_" + ship1_name
        agent2_name = "SHIP_1_" + ship2_name
        # Must add the halite count of both teams (normalized!)
        return_obs[agent1_name] = self.observation(obs, agent1_name)
        return_obs[agent2_name] = self.observation(obs, agent2_name)
        self.previous_alive_agent_ids = set((agent1_name, agent2_name))
        self.all_agent_ids = [set((agent1_name,)), set((agent2_name,))]
        self.previous_ships_halite = {
            agent1_name: 0,
            agent2_name: 0,
        }
        return return_obs
    def rllib_action_dict_to_halite(self, action_dict):
        final_action = [{}, {}]
        for rllib_agent_id in action_dict:
            action = action_dict[rllib_agent_id]
            agent_type, team, agent_id = rllib_agent_id.split("_")
            team = int(team)
            # SHIP_ACTIONS covers the full 8-action space (indices 6-7 are the shipyard
            # actions), so the same lookup works for both agent types:
            # if agent_type == "SHIP":
            #     converted_action = SHIP_ACTIONS[action]
            # elif agent_type == "SHIPYARD":
            #     converted_action = SHIPYARD_ACTIONS[action]
            converted_action = SHIP_ACTIONS[action]
            if converted_action is None:
                continue
            final_action[team][agent_id] = converted_action
        return final_action
    def extract_alive_agents_for_rllib(self, obs):
        # Build the set of currently alive agents
        agent_ids_list = [[], []]
        players = obs[0].observation.players
        for team, player in enumerate(players):
            shipyards = player[1]
            ships = player[2]
            for shipyard in shipyards:
                agent_ids_list[team].append("SHIPYARD_" + str(team) + "_" + shipyard)
            for ship in ships:
                agent_ids_list[team].append("SHIP_" + str(team) + "_" + ship)
        team1_agent_ids = set(agent_ids_list[0])
        team2_agent_ids = set(agent_ids_list[1])
        self.all_agent_ids[0] |= team1_agent_ids  # for the final reward
        self.all_agent_ids[1] |= team2_agent_ids  # for the final reward
        alive_agent_ids = team1_agent_ids | team2_agent_ids
        return alive_agent_ids
    def outcome(self, obs, dones):
        # Terminal outcome for this two-player setting: winner +1, loser -1
        team1_reward = 1 if obs[0].reward > obs[1].reward else -1
        team2_reward = 1 if obs[0].reward < obs[1].reward else -1
        outcomes = {}
        for agent_rllib_id in dones:
            agent_type, team, agent_id = agent_rllib_id.split("_")
            team = int(team)
            if team == 0:
                outcomes[agent_rllib_id] = team1_reward
            elif team == 1:
                outcomes[agent_rllib_id] = team2_reward
        return outcomes

    def get_rewards(self, obs, dones):
        if not self.env.done:
            return {agent_id: 0 for agent_id in dones}
        else:
            return self.outcome(obs, dones)
    def step(self, action_dict):
        actions = self.rllib_action_dict_to_halite(action_dict)
        obs = self.env.step(actions)
        # We have to mark dead ships and dead shipyards as terminated,
        # so we need the previous set of ids: self.previous_alive_agent_ids
        alive_agent_ids = self.extract_alive_agents_for_rllib(obs)
        # Build the done dict: agents seen last step that are now dead or still alive
        dones = {agent_id: agent_id not in alive_agent_ids for agent_id in self.previous_alive_agent_ids}
        self.previous_alive_agent_ids = alive_agent_ids
        # Newly created agents
        for alive_agent_id in alive_agent_ids:
            if alive_agent_id not in dones:
                dones[alive_agent_id] = False
        # Build the returned observations
        rllib_obs = {agent_id: self.observation(obs, agent_id) for agent_id in dones}
        # And finally the rewards
        rewards = self.get_rewards(obs, dones)
        dones["__all__"] = self.env.done
        return rllib_obs, rewards, dones, {}
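As a quick sanity check of the environment loop (sampling random legal actions from the mask; this is just for testing, not part of training):

```python
import numpy as np

env = HaliteEnv({})
obs = env.reset()
dones = {"__all__": False}
while not dones["__all__"]:
    # one random legal action per agent that is still alive
    actions = {
        agent_id: int(np.random.choice(np.flatnonzero(o["action_mask"])))
        for agent_id, o in obs.items()
        if not dones.get(agent_id, False)
    }
    obs, rewards, dones, infos = env.step(actions)
```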
import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class TorusConv2d(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super(TorusConv2d, self).__init__()
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        # Wrap the borders so the convolution sees the board as a torus
        h = torch.cat([x[:, :, :, -self.edge_size[1]:], x, x[:, :, :, :self.edge_size[1]]], dim=3)
        h = torch.cat([h[:, :, -self.edge_size[0]:], h, h[:, :, :self.edge_size[0]]], dim=2)
        h = self.conv(h)
        h = self.bn(h) if self.bn is not None else h
        return h
from ray.rllib.utils.torch_ops import FLOAT_MIN, FLOAT_MAX


class HaliteShipNet(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        nn.Module.__init__(self)
        super(HaliteShipNet, self).__init__(obs_space, action_space, None, model_config, name)
        layers, filters = 12, 32
        self.relu = nn.ReLU()
        self.conv0 = TorusConv2d(11, filters, (3, 3), True)
        self.blocks = nn.ModuleList([TorusConv2d(filters, filters, (3, 3), True) for _ in range(layers)])
        self.head_p = nn.Sequential(
            nn.Conv2d(in_channels=filters, out_channels=2, kernel_size=1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2 * 21 * 21, 8, bias=False),
        )
        self.head_v = nn.Sequential(
            nn.Conv2d(in_channels=filters, out_channels=1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(21 * 21, 256),
            nn.ReLU(),
            nn.Linear(256, 1, bias=False),
        )
        self._head_features = None
        self._avg_features = None
    def forward(self, input_dict, state, seq_lens):
        x = input_dict['obs']['obs'].to(torch.float32)
        action_mask = input_dict['obs']['action_mask']
        h = self.relu(self.conv0(x))
        for block in self.blocks:
            h = self.relu(h + block(h))
        # Weight the features by the first input channel (as in the Hungry Geese net), keeping the full spatial map
        self._head_features = (h * x[:, :1])
        self._avg_features = h
        # Mask out invalid actions by adding a large negative value to their logits
        inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)
        return self.head_p(self._head_features) + inf_mask, []

    def value_function(self):
        assert self._head_features is not None and self._avg_features is not None, "must call forward() first"
        return torch.tanh(self.head_v(self._head_features)).squeeze(1)
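For completeness, this is roughly how I plug the environment and the model into RLlib (a sketch assuming the RLlib 1.x-era API used above; the per-agent observation and action spaces still have to be declared, e.g. through the multiagent policies config, which I omit here):

```python
import ray
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env

ray.init()
register_env("halite_marl", lambda config: HaliteEnv(config))
ModelCatalog.register_custom_model("halite_ship_net", HaliteShipNet)

tune.run(
    "A2C",
    config={
        "env": "halite_marl",
        "framework": "torch",
        "num_workers": 4,
        "model": {"custom_model": "halite_ship_net"},
        # "multiagent": {...}  # policy spec with the Dict obs space and Discrete(8) actions goes here
    },
)
```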
I've basically adapted your Hungry Geese example to everything I said above. This works fine with RLlib A2C, but it doesn't scale well across workers (partly an A2C problem, partly library issues). I haven't reached convergence with A2C and RLlib yet: in 12 hours it only played about 20,000 games, and I think far more are needed. That's why I'd like to try it with HandyRL, where I think I could exploit distributed training much better.
My adaptation of the Hungry Geese preprocessing might not be the best, and the same goes for the net.
Hence my two questions:
- Could you give me some insight into how to adapt HandyRL for this?
- Could you give me some insight on the preprocessing and the network?
And one more question: do you think convergence can be achieved with all of this?
P.S.: my email address, if you want to reach out in private: joseph.amigo@hec.edu
P.P.S.: to handle the two heterogeneous agent types (SHIP and SHIPYARD), I simply added a channel filled with 1 for a SHIP and with 0.5 for a SHIPYARD.
P.P.P.S.: Halite IV rules: https://www.kaggle.com/c/halite-iv-playground-edition/overview/halite-rules
Thank you for trying a new application!
I believe RL agents can play such games well, though some additional implementation work and a lot of computational resources would be needed, and I have no idea whether it can outperform other approaches.
There seem to be two problem settings for this environment:
- a two-player game
- a many-player game
and you prefer the second one, don't you?
In the second setting, you can compute the model for each component (ship?) separately.
Some more work would probably be needed in generation.py (and train.py?) for this setting.
In the first problem setting, one of the easiest approaches is to compute the policies of all components with a single CNN.
The input shape is (H x W x features) and the output shape is (H x W x actions), so we can decide the actions of all ships at once.
Needless to say, we should in principle decide the actions sequentially... but that is computationally heavy.
Therefore, I think computing the CNN only once, or a few times (we can paint the board in a checkered pattern with several colors and then compute the actions color by color), is a (wrong but) realistic way.
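A rough sketch of what such a per-cell policy head could look like (layer sizes and names are illustrative, not taken from HandyRL; 6 ship actions are assumed as an example):

```python
import torch
import torch.nn as nn

class PerCellPolicy(nn.Module):
    """One forward pass outputs an action distribution for every board cell."""

    def __init__(self, in_channels=11, filters=32, n_actions=6):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, filters, 3, padding=1), nn.ReLU(),
            nn.Conv2d(filters, filters, 3, padding=1), nn.ReLU(),
        )
        self.head = nn.Conv2d(filters, n_actions, 1)  # 1x1 conv -> per-cell logits

    def forward(self, x):                # x: (batch, in_channels, 21, 21)
        return self.head(self.body(x))   # (batch, n_actions, 21, 21)

# At decision time, read the logits only at the cells occupied by our ships, e.g.
# action = logits[b, :, ship_y, ship_x].argmax()
```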
Thank you for your quick answer. There are indeed two problem settings, and I was referring to the first one (the two-player game) to begin with. What do you mean by deciding the actions sequentially?
I might have caused a misunderstanding.
The problem settings are:
- a game between two players
- a game among ships
and I thought that the setting you were considering was the latter, since you used the word "multi agent". Is that right?
In the latter case, one available approach is to consider this game as a game with (at most) 21 x 21 players (ships).
I think current HandyRL can already handle this setting.
In the former case, the following is my idea.
When the board size is 4 x 4, we first paint the board as:
0101
2323
1010
3232
then we first decide the actions of the ships on the cells painted "0".
Next, we decide the actions of the ships painted "1".
Next, "2".
Finally, "3".
As a result, we can decide the actions of all ships with only 4 CNN computations.
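A small sketch of this idea, reproducing the 4-colour pattern above and deciding actions colour by colour (compute_logits is a hypothetical placeholder for one CNN forward pass):

```python
import numpy as np

def paint_board(height, width):
    """Colours 0-3 per cell, matching the pattern in the 4 x 4 example."""
    colors = np.zeros((height, width), dtype=int)
    for y in range(height):
        for x in range(width):
            colors[y, x] = 2 * (y % 2) + (x + y // 2) % 2
    return colors

def decide_actions(board_obs, ship_positions, compute_logits):
    """Decide actions one colour group at a time, so ships that are close
    (small Manhattan distance) are never decided simultaneously."""
    colors = paint_board(21, 21)
    actions = {}
    for color in range(4):
        # re-run the CNN, ideally conditioned on the actions already fixed
        logits = compute_logits(board_obs, actions)   # assumed shape: (n_actions, 21, 21)
        for ship_id, (y, x) in ship_positions.items():
            if colors[y, x] == color:
                actions[ship_id] = int(logits[:, y, x].argmax())
    return actions
```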
Actually, I don't mind using one setting or the other, I would just like to "make it work". But the only way I can see to handle setting one (two players) is a centralized network that outputs the actions of all of a player's agents. The problem is that I don't know of any way to take several actions from the output of the net: models like IMPALA (and actually every model I know) are made to take one action per policy (or Q-value) head.
That's why I think handling the problem as a MARL problem is better. This means considering as many agents as there actually are and calling the net once per agent to determine its next action.
Do you mean, with your painted board, that we have one net outputting actions for each kind of ship (own ships, own shipyards, opponent ships, opponent shipyards)? So 4 agents running?
I also have a question about the net you're using for Hungry Geese:
def forward(self, x, _=None):
    h = F.relu_(self.conv0(x))
    for block in self.blocks:
        h = F.relu_(h + block(h))
    h_head = (h * x[:, :1]).view(h.size(0), h.size(1), -1).sum(-1)
    h_avg = h.view(h.size(0), h.size(1), -1).mean(-1)
    p = self.head_p(h_head)
    v = torch.tanh(self.head_v(torch.cat([h_head, h_avg], 1)))
    return {'policy': p, 'value': v}
Why do you do this operation:
h * x[:,:1]
It's an element-wise product, broadcast along the channel axis, of h of shape (batch_size, 32, 11, 7) with x[:, :1] of shape (batch_size, 1, 11, 7), where that channel encodes the head position of the goose in question. Do you have an intuitive explanation of the idea behind it?
This operation gathers the features at the position of the goose's head.
In Hungry Geese, the state around the head of a goose is generally the most important information for selecting an action and for alive-or-dead detection, and the farther a pixel is from the head, the less important its state is.
For value estimation, however, global information, which includes the length of each goose, is also important. That's why the head features and the averaged features are concatenated before the last layer of the value head.
(I also posted this explanation in the thread of our code on Kaggle.)
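To make this concrete: since x[:, :1] is a one-hot map of the head position, the multiplication followed by the spatial sum simply reads out the feature vector at the head cell. A tiny check (7 x 11 board as in Hungry Geese):

```python
import torch

B, C, H, W = 1, 32, 7, 11
h = torch.randn(B, C, H, W)           # CNN features
head = torch.zeros(B, 1, H, W)
head[0, 0, 3, 5] = 1                  # one-hot channel marking the head cell

h_head = (h * head).view(B, C, -1).sum(-1)   # (B, C): features gathered at the head
assert torch.allclose(h_head[0], h[0, :, 3, 5])
```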
Selecting the actions of several components is not difficult in itself.
You can output a tensor of shape (batch_size, n_components, n_actions), compute a softmax over its last axis, and then decide an action for each component.
However, this procedure becomes problematic when we want the components to cooperate with each other.
That's why I thought of a way to avoid deciding the actions of components with small Manhattan distances between them at the same time.
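For example, independent per-component sampling from such an output looks like this (the cooperation problem is exactly that these samples are drawn independently):

```python
import torch

batch_size, n_components, n_actions = 1, 5, 6
logits = torch.randn(batch_size, n_components, n_actions)        # network output
probs = torch.softmax(logits, dim=-1)                            # softmax over the action axis
actions = torch.distributions.Categorical(probs=probs).sample()  # (batch_size, n_components)
```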