Multi-Agent COMmunication module for PyTorch.
This module takes messages of length 'msg_dim' from a number of agents ('n_agents') and outputs a single message that "summarizes" all the input messages. The output always has size 'msg_dim', regardless of the number of agents.
- In the case of a single agent (self-interaction / self-communication), the encoded input and the encoded output are identical, as expected, since there is no interaction with other agents. The encoder and decoder networks supplied by the user can therefore be trained in this setting to be inverses of each other.
- The output size is independent of the number of agents in the input, and the order of the agents' messages doesn't matter. The module is order-invariant, like DeepSets, and additionally lets the encoded messages interact with each other before they are pooled.
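Both properties above can be checked numerically. The following is a minimal sketch, not the module's actual code: `attend_and_pool` is a hypothetical helper that leaves out the encoder and decoder, assumes the softmax normalizes each row of the score matrix, and assumes mean pooling.

```python
import torch


def attend_and_pool(z: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: self-attention over agents, then mean pooling.

    z has shape [n_agents, dim]; the result has shape [dim].
    """
    scores = z @ z.T                         # [n_agents, n_agents]
    weights = torch.softmax(scores, dim=-1)  # assumption: each row sums to 1
    mixed = weights @ z                      # [n_agents, dim]
    return mixed.mean(dim=0)                 # [dim]


torch.manual_seed(0)

# Single agent: the 1x1 score matrix softmaxes to 1, so the message is unchanged.
single = torch.randn(1, 3)
single_out = attend_and_pool(single)

# Several agents: shuffling the rows does not change the pooled result.
z = torch.randn(5, 3)
perm = torch.randperm(5)
out = attend_and_pool(z)
out_shuffled = attend_and_pool(z[perm])
```

Shuffling the agents permutes the rows and columns of the weight matrix in the same way, so the mixed messages are only reordered and the pooled vector stays the same up to floating-point rounding.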
Borrowing from the attention mechanism:
- Every message is passed through a single "encoder" network (in contrast to the separate query, key, and value networks in attention).
- All the encoded messages are stacked up to form a matrix with size of [n_agents, encoded_msg_dim].
- We multiply this [n_agents, encoded_msg_dim] matrix by its own transpose:
- [n_agents, encoded_msg_dim] @ [encoded_msg_dim, n_agents] = [n_agents, n_agents].
- Softmaxing the columns of this matrix.
- Matrix product of this new [n_agents, n_agents] matrix with the original encoded messages matrix:
- [n_agents, n_agents] @ [n_agents, encoded_msg_dim] = [n_agents, encoded_msg_dim]
- Decoding this new encoded messages matrix to get [n_agents, msg_dim]
- Reducing the n_agents dimension with a mean or max operation to get a single vector [msg_dim].
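The steps above can be sketched for a single, unbatched set of messages as follows. This is an illustration under stated assumptions, not the module's actual implementation: `macom_sketch` and its `reduce` argument are hypothetical names, and the softmax is assumed to normalize each row of the [n_agents, n_agents] matrix.

```python
import torch
from torch import nn


def macom_sketch(msgs, encoder, decoder, reduce="mean"):
    """Hypothetical unbatched walk-through of the listed steps.

    msgs: [n_agents, msg_dim] -> returns [msg_dim].
    """
    z = encoder(msgs)                        # [n_agents, encoded_msg_dim]
    scores = z @ z.T                         # [n_agents, n_agents]
    weights = torch.softmax(scores, dim=-1)  # assumption: normalize each row
    mixed = weights @ z                      # [n_agents, encoded_msg_dim]
    decoded = decoder(mixed)                 # [n_agents, msg_dim]
    if reduce == "mean":
        return decoded.mean(dim=0)
    if reduce == "max":
        return decoded.max(dim=0).values
    raise ValueError(f"unknown reduce: {reduce!r}")


torch.manual_seed(0)
msg_dim, encoded_msg_dim = 4, 2
encoder = nn.Linear(msg_dim, encoded_msg_dim)
decoder = nn.Linear(encoded_msg_dim, msg_dim)

# The same networks handle any number of agents; the result is always [msg_dim].
summary_3 = macom_sketch(torch.randn(3, msg_dim), encoder, decoder)
summary_10 = macom_sketch(torch.randn(10, msg_dim), encoder, decoder, reduce="max")
```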
All these operations are performed on a batch, so batch_size is prepended as the first dimension of every matrix. The explanation above leaves batch_size out because it is easier to follow without it.
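With the batch dimension in place, the same matrix products can be written by transposing only the last two dimensions. This is a small sketch of the shapes involved, assuming row-wise softmax and mean pooling; it is not taken from the module's source.

```python
import torch

batch_size, n_agents, encoded_msg_dim = 20, 10, 2
z = torch.randn(batch_size, n_agents, encoded_msg_dim)

# Transpose only the last two dimensions; matmul broadcasts over the batch.
scores = z @ z.transpose(-1, -2)         # [batch_size, n_agents, n_agents]
weights = torch.softmax(scores, dim=-1)  # assumption: normalize each row
mixed = weights @ z                      # [batch_size, n_agents, encoded_msg_dim]
pooled = mixed.mean(dim=1)               # agents now sit on dimension 1
```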
import torch
from torch import nn

batch_size = 20
n_agents = 10
msg_dim = 4
latent_msg_dim = 2
x = torch.randn(batch_size, n_agents, msg_dim)
macom = Macom(
    encoding_net=nn.Linear(msg_dim, latent_msg_dim),
    decoding_net=nn.Linear(latent_msg_dim, msg_dim),
)
output = macom(x)
In this case, the output size will be [20, 4], or in the general case [batch_size, msg_dim], independent of the number of agents.