FlashMHA is a PyTorch implementation of the Flash Multi-Head Attention mechanism. It is designed to be efficient and flexible, supporting both causal and non-causal attention, and it builds on Flash Attention, a highly efficient attention algorithm designed for GPUs.
You can install FlashMHA using pip:
pip install FlashMHA
After installing FlashMHA, you can import the FlashAttention or FlashMHA module for use in your code:
from FlashMHA import FlashAttention
or
from FlashMHA import FlashMHA
Now you can create an instance of the FlashAttention class or the FlashMHA class and use it in your code accordingly.
Example usage:
# Import the necessary module
from FlashMHA import FlashAttention
# Create an instance of FlashAttention
flash_attention = FlashAttention(causal=False, dropout=0.0)
# Use the FlashAttention instance in your code
output = flash_attention(query, key, value)
# Import the necessary module
from FlashMHA import FlashMHA
# Create an instance of FlashMHA (embed_dim and num_heads are required; 512 and 8 are illustrative values)
flash_mha_attention = FlashMHA(embed_dim=512, num_heads=8, causal=False, dropout=0.0)
# Use the FlashMHA instance in your code
output = flash_mha_attention(query, key, value)
Make sure to replace query, key, and value with your own input tensors.
Now you can utilize the FlashAttention or FlashMHA module in your code by following the provided examples.
In this example, query, key, and value are input tensors with shape (batch_size, sequence_length, embed_dim). The FlashMHA model applies the multi-head attention mechanism to these inputs and returns the output tensor.
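Putting the pieces together, here is a minimal end-to-end sketch for FlashMHA self-attention. The tensor sizes (batch of 2, sequence length of 128, embed_dim of 512, 8 heads) are illustrative assumptions rather than requirements of the library, and the call simply follows the shapes described above.
# Illustrative end-to-end example (all sizes are arbitrary)
import torch
from FlashMHA import FlashMHA
batch_size, sequence_length, embed_dim, num_heads = 2, 128, 512, 8
# embed_dim and num_heads are required constructor arguments
flash_mha = FlashMHA(embed_dim=embed_dim, num_heads=num_heads, causal=False, dropout=0.0)
# Self-attention: the same tensor is passed as query, key, and value
x = torch.randn(batch_size, sequence_length, embed_dim)
output = flash_mha(x, x, x)  # expected shape: (batch_size, sequence_length, embed_dim)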
FlashAttention is a PyTorch module that implements the Flash Attention mechanism, a highly efficient attention mechanism designed for GPUs. It provides a fast and flexible solution for attention computations in deep learning models.
Parameters:
- causal (bool, optional): If set to True, applies causal masking to the sequence. Default: False.
- dropout (float, optional): The dropout probability. Default: 0.
- flash (bool, optional): If set to True, enables the use of Flash Attention. Default: False.
Forward arguments:
- q (Tensor): The query tensor of shape (batch_size, num_heads, query_length, embed_dim).
- k (Tensor): The key tensor of shape (batch_size, num_heads, key_length, embed_dim).
- v (Tensor): The value tensor of shape (batch_size, num_heads, value_length, embed_dim).
- mask (Tensor, optional): An optional mask tensor of shape (batch_size, num_heads, query_length, key_length), used to mask out specific positions. Default: None.
- attn_bias (Tensor, optional): An optional additive bias tensor of shape (batch_size, num_heads, query_length, key_length), applied to the attention weights. Default: None.
Returns:
- output (Tensor): The output tensor of shape (batch_size, query_length, embed_dim).
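As a hedged sketch of this forward signature, the call below follows the documented q/k/v shapes; the tensor sizes are illustrative, and the keyword usage simply mirrors the parameter names listed above.
# Illustrative call following the documented q/k/v shapes
import torch
from FlashMHA import FlashAttention
batch_size, num_heads, query_length, key_length, embed_dim = 2, 8, 128, 128, 64
flash_attention = FlashAttention(causal=False, dropout=0.0)
q = torch.randn(batch_size, num_heads, query_length, embed_dim)
k = torch.randn(batch_size, num_heads, key_length, embed_dim)
v = torch.randn(batch_size, num_heads, key_length, embed_dim)
# Optional additive bias over the attention weights (e.g. a relative position bias)
attn_bias = torch.zeros(batch_size, num_heads, query_length, key_length)
output = flash_attention(q, k, v, attn_bias=attn_bias)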
FlashMHA is a PyTorch module that implements the Flash Multi-Head Attention mechanism, which combines multiple FlashAttention layers. It is designed to be efficient and flexible, allowing for both causal and non-causal attention.
Parameters:
- embed_dim (int): The dimension of the input embedding.
- num_heads (int): The number of attention heads.
- bias (bool, optional): If set to False, the layers will not learn an additive bias. Default: True.
- batch_first (bool, optional): If True, then the input and output tensors are provided as (batch, seq, feature). Default: True.
- dropout (float, optional): The dropout probability. Default: 0.
- causal (bool, optional): If True, applies causal masking to the sequence. Default: False.
- device (torch.device, optional): The device to run the model on. Default: None.
- dtype (torch.dtype, optional): The data type to use for the model parameters. Default: None.
Forward arguments:
- query (Tensor): The query tensor of shape (batch_size, sequence_length, embed_dim).
- key (Tensor): The key tensor of shape (batch_size, sequence_length, embed_dim).
- value (Tensor): The value tensor of shape (batch_size, sequence_length, embed_dim).
Returns:
- output (Tensor): The output tensor of shape (batch_size, sequence_length, embed_dim).
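For a causal (autoregressive) configuration, a sketch that also exercises the device and dtype arguments might look as follows; the choice of CUDA and half precision is an assumption for illustration, not a requirement documented here.
# Illustrative causal configuration with explicit device and dtype
import torch
from FlashMHA import FlashMHA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is typical for Flash Attention kernels on GPU; fall back to float32 on CPU
dtype = torch.float16 if device.type == "cuda" else torch.float32
causal_mha = FlashMHA(
    embed_dim=512,
    num_heads=8,
    causal=True,
    dropout=0.1,
    device=device,
    dtype=dtype,
)
x = torch.randn(4, 256, 512, device=device, dtype=dtype)
output = causal_mha(x, x, x)
print(output.shape)  # expected: torch.Size([4, 256, 512])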
FlashAttention and FlashMHA are open-source software, licensed under the MIT license. For more details, please refer to the GitHub repository.